当前位置: 代码网 > 科技>人工智能>智能机器人 > 【自然语言处理】【大模型】DeepSeek-V2论文解析

【自然语言处理】【大模型】DeepSeek-V2论文解析

2024年07月31日 智能机器人 我要评论
论文地址:https://arxiv.org/pdf/2405.04434​传统Transformer采用MHA(Multi-Head Attention),但是kv cache会成为推理瓶颈。MQA(Multi-Query Attention)和GQA(Grouped-Query Attention)可以一定程度减少kv cache,但效果上不如MHA。DeepSeek-V2设计了一种称为MLA(Multi-Head Latent Attention)的注意力机制。MLA通过低秩key-value联合压

论文地址:https://arxiv.org/pdf/2405.04434

一、简介

  • deepseek-v2是一个总参数为236b的moe模型,每个token仅激活21b的参数,并支持128k的上下文长度。
  • 提出了multi-head latent attention(mla),通过压缩kv cache至隐向量,从而保证高效推理。
  • 相比于deepseek 67b,deepseek-v2实现了更好的表现,节约了42.5%的训练成本,降低了93.3%的kv cache,提升最大吞吐5.76倍。
  • 预训练语料包含了8.1t tokens并进一步进行sft和rl。

二、模型结构

1. mla(multi-head latent attention)

​ 传统transformer采用mha(multi-head attention),但是kv cache会成为推理瓶颈。mqa(multi-query attention)和gqa(grouped-query attention)可以一定程度减少kv cache,但效果上不如mha。deepseek-v2设计了一种称为mla(multi-head latent attention)的注意力机制。mla通过低秩key-value联合压缩,实现了比mha更好的效果并且需要的kv cache要小很多。

1.1 标准mha

​ 令 d d d为embedding维度, n h n_h nh是注意力头的数量, d h d_h dh是每个头的维度, h t ∈ r d \textbf{h}_t\in\mathbb{r}^d htrd是注意力层中第 t t t个token的输入。标准mha通过三个矩阵 w q , w k , w v ∈ r d h n h × d w^q,w^k,w^v\in\mathbb{r}^{d_h n_h\times d} wq,wk,wvrdhnh×d来产生 q t , k t , v t ∈ r d h n h \textbf{q}_t,\textbf{k}_t,\textbf{v}_t\in\mathbb{r}^{d_h n_h} qt,kt,vtrdhnh
q t = w q h t k t = w k h t v t = w v h t \begin{align} \textbf{q}_t&=w^q\textbf{h}_t \tag{1}\\ \textbf{k}_t&=w^k\textbf{h}_t \tag{2}\\ \textbf{v}_t&=w^v\textbf{h}_t \tag{3}\\ \end{align} \\ qtktvt=wqht=wkht=wvht(1)(2)(3)
在mha中 q t , k t , v t \textbf{q}_t,\textbf{k}_t,\textbf{v}_t qt,kt,vt会被划分为 n h n_h nh个头:
[ q t , 1 ; q t , 2 ; … , q t , n h ] = q t [ k t , 1 ; k t , 2 ; … , k t , n h ] = k t [ v t , 1 ; v t , 2 ; … , v t , n h ] = v t o t , i = ∑ j = 1 t softmax ( q t , i ⊤ k j , i d h ) v j , i u t = w o [ o t , 1 ; o t , 2 ; … , o t , n h ] \begin{align} &[\textbf{q}_{t,1};\textbf{q}_{t,2};\dots,\textbf{q}_{t,n_h}]=\textbf{q}_t \tag{4}\\ &[\textbf{k}_{t,1};\textbf{k}_{t,2};\dots,\textbf{k}_{t,n_h}]=\textbf{k}_t \tag{5}\\ &[\textbf{v}_{t,1};\textbf{v}_{t,2};\dots,\textbf{v}_{t,n_h}]=\textbf{v}_t \tag{6}\\ &\textbf{o}_{t,i}=\sum_{j=1}^t\text{softmax}(\frac{\textbf{q}_{t,i}^\top\textbf{k}_{j,i}}{\sqrt{d_h}})\textbf{v}_{j,i} \tag{7}\\ &\textbf{u}_t=w^o[\textbf{o}_{t,1};\textbf{o}_{t,2};\dots,\textbf{o}_{t,n_h}] \tag{8}\\ \end{align} \\ [qt,1;qt,2;,qt,nh]=qt[kt,1;kt,2;,kt,nh]=kt[vt,1;vt,2;,vt,nh]=vtot,i=j=1tsoftmax(dh qt,ikj,i)vj,iut=wo[ot,1;ot,2;,ot,nh](4)(5)(6)(7)(8)
其中 q t , i , k t , i , v t , i ∈ r d h \textbf{q}_{t,i},\textbf{k}_{t,i},\textbf{v}_{t,i}\in\mathbb{r}^{d_h} qt,i,kt,i,vt,irdh是第 i i i个注意力头的query、key和value, w o ∈ r d × d h n h w^o\in\mathbb{r}^{d\times d_h n_h} word×dhnh是输出投影矩阵。在推理时,所有的key和value都会被缓存来加速推理。对于每个token,mha需要缓存 2 n h d h l 2n_h d_h l 2nhdhl个元素。

1.2 低秩key-value联合压缩

在这里插入图片描述

​ mla通过低秩联合压缩key和value来减少kv cache:
c t k v = w d k v h t k t c = w u k c t k v v t c = w u v c t k v \begin{align} \textbf{c}_t^{kv}&=w^{dkv}\textbf{h}_t \tag{9}\\ \textbf{k}_t^c&=w^{uk}\textbf{c}_t^{kv} \tag{10}\\ \textbf{v}_t^c&=w^{uv}\textbf{c}_t^{kv} \tag{11}\\ \end{align} \\ ctkvktcvtc=wdkvht=wukctkv=wuvctkv(9)(10)(11)
其中 c t k v ∈ r d c \textbf{c}_t^{kv}\in\mathbb{r}^{d_c} ctkvrdc是用于压缩key和value的隐向量; d c ( ≪ d h n h ) d_c(\ll d_h n_h) dc(dhnh)表示kv压缩的维度; w d k v ∈ r d c × d w^{dkv}\in\mathbb{r}^{d_c\times d} wdkvrdc×d是下投影矩阵, w u k , w u v ∈ r d h n h × d c w^{uk},w^{uv}\in\mathbb{r}^{d_h n_h\times d_c} wuk,wuvrdhnh×dc表示上投影矩阵。在推理时,mla仅需要缓存 c t k v \textbf{c}_t^{kv} ctkv,因此kv cache仅需要缓存 d c l d_c l dcl个元素。此外,在推理时可以把 w u k w^{uk} wuk吸收到 w q w^q wq w u v w^{uv} wuv吸收到 w o w^o wo中,这样甚至都不需要计算key和value。

​ 此外,为了在训练时降低激活的显存占用,对query也进行低秩压缩,即使其不能降低kv cache。具体来说,
c t q = w d q h t q t c = w u q c t q \begin{align} \textbf{c}_t^q&=w^{dq}\textbf{h}_t \tag{12}\\ \textbf{q}_t^c&=w^{uq}\textbf{c}_t^q \tag{13} \\ \end{align} \\ ctqqtc=wdqht=wuqctq(12)(13)
其中 c t q ∈ r d c ′ \textbf{c}_t^q\in\mathbb{r}^{d_c'} ctqrdc是query的压缩后隐向量; d c ′ ( ≪ d h n h ) d_c'(\ll d_h n_h) dc(dhnh)表示query的压缩维度; w d q ∈ r d c ′ × d , w u q ∈ r d h n h × d c ′ w^{dq}\in\mathbb{r}^{d_c'\times d},w^{uq}\in\mathbb{r}^{d_h n_h\times d_c'} wdqrdc×d,wuqrdhnh×dc是下投影矩阵和上投影矩阵。

1.3 解耦rope

​ rope与低秩kv压缩并不兼容。具体来说,rope对于query和key是位置敏感的。若将rope应用在 k t c \textbf{k}_t^c ktc上,等式10中的 w u k w^{uk} wuk将与位置敏感rope矩阵耦合。但是在推理时, w u k w^{uk} wuk就无法被吸收到 w q w^q wq中,因为对当前生成token相关的rope矩阵将位于 w q w^q wq w u k w^{uk} wuk之间,而矩阵乘法不满足交换律。因此,推理时必须重新计算前面token的key,这会显著影响推理效率。

​ 为了解决这个问题,提出使用额外的多头query q t , i r ∈ r d h r \textbf{q}_{t,i}^r\in\mathbb{r}^{d_h^r} qt,irrdhr和共享key k t r ∈ r d h r \textbf{k}_t^r\in\mathbb{r}^{d_h^r} ktrrdhr来携带rope,其中 d h r d_h^r dhr表示解耦query和key的每个头的维度。在mla中使用解耦rope策略的方式为:
q t r = [ q t , 1 r ; q t , 2 r ; …   ; q t , n h r ] = rope ( w q r c t q ) k t r = rope ( w k r h t ) q t , i = [ q t , i c ; q t , i r ] k t , i = [ k t , i c ; k t r ] o t , i = ∑ j = 1 t softmax j ( q t , i ⊤ k j , i d h + d h r ) v j , i c u t = w o [ o t , 1 ; o t , 2 ; …   ; o t , n h ] \begin{align} \textbf{q}_t^r&=[\textbf{q}_{t,1}^r;\textbf{q}_{t,2}^r;\dots;\textbf{q}_{t,n_h}^r]=\text{rope}(w^{qr}\textbf{c}_t^q) \tag{14}\\ \textbf{k}_t^r&=\text{rope}(w^{kr}\textbf{h}_t) \tag{15}\\ \textbf{q}_{t,i}&=[\textbf{q}_{t,i}^c;\textbf{q}_{t,i}^r] \tag{16}\\ \textbf{k}_{t,i}&=[\textbf{k}_{t,i}^c;\textbf{k}_t^r] \tag{17} \\ \textbf{o}_{t,i}&=\sum_{j=1}^t\text{softmax}_j(\frac{\textbf{q}_{t,i}^\top\textbf{k}_{j,i}}{\sqrt{d_h+d_h^r}})\textbf{v}_{j,i}^c \tag{18} \\ \textbf{u}_t&=w^o[\textbf{o}_{t,1};\textbf{o}_{t,2};\dots;\textbf{o}_{t,n_h}] \tag{19}\\ \end{align} \\ qtrktrqt,ikt,iot,iut=[qt,1r;qt,2r;;qt,nhr]=rope(wqrctq)=rope(wkrht)=[qt,ic;qt,ir]=[kt,ic;ktr]=j=1tsoftmaxj(dh+dhr qt,ikj,i)vj,ic=wo[ot,1;ot,2;;ot,nh](14)(15)(16)(17)(18)(19)
其中 w q r ∈ r d h r n h × d c ′ w^{qr}\in\mathbb{r}^{d_h^r n_h\times d_c'} wqrrdhrnh×dc w k r ∈ r d h r × d w^{kr}\in\mathbb{r}^{d_h^r\times d} wkrrdhr×d是用于产生解耦query和key的矩阵; rope ( ⋅ ) \text{rope}(\cdot) rope()表示应用rope的操作; [ ⋅ ; ⋅ ] [\cdot;\cdot] [;]表示拼接操作。在推理时,解耦的key也需要被缓存。因此,deekseek-v2需要的总kv cache包含 ( d c + d h r ) l (d_c+d_h^r)l (dc+dhr)l个元素。

1.4 结论

在这里插入图片描述

mla能够通过更少的kv cache实现比mha更好的效果

2. 整体结构

2.1 基础结构

​ 对于ffn层,利用deepseekmoe架构,即将专家划分为更细粒度,从而获得更专业化的专家以及获取更准确的知识。在具有相同激活和总专家参数的情况下,deepseekmoe能够大幅度超越传统moe架构。

​ 令 u t \textbf{u}_t ut是第t个token对ffn的输入,那么计算ffn的输出 h t ′ \textbf{h}_t' ht为:
h t ′ = u t + ∑ i = 1 n s ffn i ( s ) ( u t ) + ∑ i = 1 n r g i , t ffn i ( r ) ( u t ) g i , t = { s i , t , s i , t ∈ topk ( { s j , t ∣ 1 ≤ j ≤ n r } , k r ) 0 , otherwise s i , t = softmax i ( u t ⊤ e i ) \begin{align} \textbf{h}_t'&=\textbf{u}_t+\sum_{i=1}^{n_s}\text{ffn}_i^{(s)}(\textbf{u}_t)+\sum_{i=1}^{n_r}g_{i,t}\text{ffn}_{i}^{(r)}(\textbf{u}_t) \tag{20}\\ g_{i,t}&=\begin{cases} s_{i,t},& s_{i,t}\in\text{topk}(\{s_{j,t}|1\leq j\leq n_r\},k_r)\\ 0,&\text{otherwise} \end{cases}\tag{21}\\ s_{i,t}&=\text{softmax}_i(\textbf{u}_t^\top \textbf{e}_i) \tag{22}\\ \end{align} \\ htgi,tsi,t=ut+i=1nsffni(s)(ut)+i=1nrgi,tffni(r)(ut)={si,t,0,si,ttopk({sj,t∣1jnr},kr)otherwise=softmaxi(utei)(20)(21)(22)
其中 n s n_s ns n r n_r nr表示共享专家和路由专家的数量; ffn i ( s ) ( ⋅ ) \text{ffn}_i^{(s)}(\cdot) ffni(s)() ffn i ( r ) ( ⋅ ) \text{ffn}_i^{(r)}(\cdot) ffni(r)()表示第i个共享专家和第i个路由专家; k r k_r kr表示激活路由专家的数量; g i , t g_{i,t} gi,t是第i个专家的门限值; e i \textbf{e}_i ei是当前层第i个路由专家的中心。

2.2 设备受限路由

​ 设计了一种设备受限路由机制来控制moe相关的通信成本。当采用专家并行时,路由专家将分布在多个设备上。对于每个token,moe相关的通信频率与目标专家覆盖的设备数量成正比。由于在deepseekmoe中细粒度专家划分,激活专家的数量会很大,因此应用专家并行时,与moe相关的通信将更加昂贵。

​ 对于deepseek-v2,除了路由专家会选择top-k个以外,还会确保每个token的目标专家最多分布在m个设备上。具体来说,对于每个token,先选择包含最高分数专家的m个设备。然后在这m个设备上执行top-k选择。在实践中,当 m ≥ 3 m\geq 3 m3时,设备受限路由能够实现与不受限top-k路由大致一致的良好性能。

2.3 用于负载均衡的辅助loss

​ 不平衡的负载会增加路由坍缩的风险,使一些专家无法得到充分的训练和利用。此外,当使用专家并行时,不平衡的负载降低计算效率。在deepseek-v2训练时,设计了三种辅助损失函数用于控制专家级别负载均衡 ( l expbal ) (\mathcal{l}_{\text{expbal}}) (lexpbal)、设备级别负载均衡 ( l devbal ) (\mathcal{l}_{\text{devbal}}) (ldevbal)和通信均衡 l commbal \mathcal{l}_{\text{commbal}} lcommbal

专家级均衡loss。专家级均衡loss用于缓解路由坍缩问题:
l expbal = α 1 ∑ i = 1 n r f i p i , f i = n r k r t ∑ t = 1 t 1 (token t selects expert i) p i = 1 t ∑ t = 1 t s i , t \begin{align} \mathcal{l}_{\text{expbal}}&=\alpha_1\sum_{i=1}^{n_r}f_ip_i, \tag{23} \\ f_i&=\frac{n_r}{k_r t}\sum_{t=1}^t\mathbb{1}\text{(token t selects expert i)} \tag{24} \\ p_i&=\frac{1}{t}\sum_{t=1}^t s_{i,t} \tag{25} \\ \end{align} \\ lexpbalfipi=α1i=1nrfipi,=krtnrt=1t1(token t selects expert i)=t1t=1tsi,t(23)(24)(25)
其中 α 1 \alpha_1 α1是称为专家级均衡因子的超参数; 1 ( ⋅ ) \mathbb{1}(\cdot) 1()是指示函数; t t t是序列中token的数量。

设备级均衡loss。除了专家级均衡loss以外,也设计了设备级别均衡loss来确保跨设备均衡计算。在deepseek-v2训练过程中,将所有的专家划分至 d d d { e 1 , e 2 , … , e d } \{\mathcal{e}_1,\mathcal{e}_2,\dots,\mathcal{e}_d\} {e1,e2,,ed}并在单个设备上部署每个组。设备级均衡loss计算如下:
l devbal = α 2 ∑ i = 1 d f i ′ p i ′ f i ′ = 1 e i ∑ j ∈ e i f j p i ′ = ∑ j ∈ e i p j \begin{align} \mathcal{l}_{\text{devbal}}&=\alpha_2\sum_{i=1}^d f_i' p_i'\tag{26} \\ f_i'&=\frac{1}{\mathcal{e}_i}\sum_{j\in\mathcal{e}_i}f_j \tag{27} \\ p_i'&=\sum_{j\in\mathcal{e}_i}p_j \tag{28} \\ \end{align} \\ ldevbalfipi=α2i=1dfipi=ei1jeifj=jeipj(26)(27)(28)
其中 α 2 \alpha_2 α2是称为设备级均衡因子的超参数。

通信均衡loss。通信均衡loss能够确保每个设备通信的均衡。虽然设备限制路由机制能够确保每个设备发送信息有上限,但是当某个设备比其他设备接收更多的tokens,那么实际通信效率将会有影响。为了缓解这个问题,设计了一种通信均衡loss如下:
l commbal = α 3 ∑ t = 1 d f i ′ ′ p i ′ ′ f i ′ ′ = d m t ∑ t = 1 t 1 (token t is sent to device i) p i ′ ′ = ∑ j ∈ e i p j \begin{align} \mathcal{l}_{\text{commbal}}&=\alpha_3\sum_{t=1}^d f_i''p_i''\tag{29} \\ f_i''&=\frac{d}{mt}\sum_{t=1}^t\mathbb{1}\text{(token t is sent to device i)}\tag{30} \\ p_i''&=\sum_{j\in\mathcal{e}_i}p_j\tag{31} \\ \end{align} \\ lcommbalfi′′pi′′=α3t=1dfi′′pi′′=mtdt=1t1(token t is sent to device i)=jeipj(29)(30)(31)
其中 α 3 \alpha_3 α3是称为通信均衡因子的超参数。设备受限路由机制操作主要确保每个设备至多向其他设备传输mt个hidden states。同时,通信均衡loss用来鼓励每个设备从其他设备接受mt个hidden states。通信均衡loss确保设备间信息均衡交换,实现高效通信。

2.4 token-dropping策略

​ 虽然均衡loss的目标是确保均衡负载,但是其并不能严格确保负载均衡。为了进一步缓解由于不均衡导致的计算浪费,在训练时引入了设备级别的token-dropping策略。该方法会先计算每个设备的平均计算预算,这意味着每个设备的容量因子等于1.0。然而,在每个设备上drop具有最低affinity分数的token,直到达到计算预算。此外,确保大约10%的训练序列的token永远不会被drop。这样,可以根据效率要求灵活地决定是否在推理过程中drop token,并确保训练和推理的一致性。

三、预训练

1. 实验设置

1.1 数据构造

​ 数据处理过程同deepseek 67b,并进一步扩展数据量和质量。采用与deepseek 67b相同的tokenizer。预训练语料包含8.1t tokens,中文token比英文多12%。

1.2 超参数

​ 略

1.3 infrastructures

​ deepseek-v2训练基于hai-llm框架。利用16路0气泡流水并行、8路专家并行和zero-1数据并行。考虑到deepseek-v2具有相对较少的激活参数,并且对一部分操作进行重计算来节约激活显存,因此可以不使用张量并行,从而降低通信开销。此外,为了进一步提高训练效率,使用专家并行all-to-all通信来重叠共享专家的计算。使用定制化的cuda核来改善通信、路由算法和不同专家之间融合线性计算。此外,mla基于改善版本的flashattention-2进行优化。

1.4 长上下文扩展

​ 使用yarn将上下文窗口尺寸从4k扩展至128k。

2. 评估

在这里插入图片描述

四、对齐

sft。 使用了150万样本的微调数据,其中120万是用于有用性,30万则用于安全性。

强化学习。仍然采用grpo。

结果
在这里插入图片描述

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com