当前位置: 代码网 > it编程>游戏开发>ar > Llama改进之——SwiGLU激活函数

Llama改进之——SwiGLU激活函数

2024年08月02日 ar 我要评论
本文介绍如何实现SwiGLU。

引言

今天介绍llama模型引入的关于激活函数的改进——swiglu1,该激活函数取得了不错的效果,得到了广泛地应用。

swiglu是glu的一种变体,其中包含了glu和swish激活函数。

glu

glu(gated linear units,门控线性单元)2引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
glu ( x , w , v , b , c ) = σ ( x w + b ) ⊗ ( x v + c ) (1) \text{glu}(x,w,v,b,c) = \sigma(xw+b) \otimes (xv+c) \tag 1 glu(x,w,v,b,c)=σ(xw+b)(xv+c)(1)
这里 w , v w,v w,v以及 b , c b,c b,c分别是这两个线性层的参数; σ ( x w + b ) \sigma(xw+b) σ(xw+b)作为门控,控制 x v + c xv+c xv+c的输出。

这里使用 σ \sigma σ作为激活函数,修改改激活函数得到的变体通常能带来更好的性能表现,比如swiglu修改激活函数为swish。我们来看下swish激活函数。

swish

swish3激活函数的形式为:
swish β ( x ) = x σ ( β x ) (2) \text{swish}_\beta(x) = x \sigma(\beta x) \tag 2 swishβ(x)=xσ(βx)(2)
其中 σ ( x ) \sigma(x) σ(x)是sigmoid函数; β \beta β是一个可学习的参数。

可以通过下面的代码画出swish激活函数在不同参数 β \beta β下的图像:

import numpy as np
import matplotlib.pyplot as plt

def swish(x, beta):
  return x / (1 + np.exp(-beta*x))

x = np.linspace(-10, 10, 100)
betas = [0.1, 1.0, 10.0]

plt.figure(figsize=(10, 6))

for beta in betas:
    y = swish(x, beta)
    plt.plot(x, y, label=f'beta={beta}')

plt.legend()
plt.title('swish activation function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.grid(true)
plt.show()

image-20240428224729925

可以看到3,当 β \beta β趋近于 0 0 0时,swish函数趋近于线性函数 y = x 2 y=x^2 y=x2;当 β \beta β趋近于无穷大时,swish函数趋近于relu函数;当 β \beta β取值为 1 1 1时,swish函数是光滑且非单调的,等价于参考4中介绍的silu。

swish与relu之间最显著的区别是当 x < 0 x < 0 x<0时swish的非单调“凸起”3

swiglu

如前文所述,将公式(1)中glu的激活函数改为swish即变成了所谓的swiglu激活函数1
swiglu ( x , w , v ) = swish β ( x w ) ⊗ ( x v ) (3) \text{swiglu}(x,w,v) = \text{swish}_\beta(xw) \otimes (xv) \tag{3} swiglu(x,w,v)=swishβ(xw)(xv)(3)
这里省略了偏置项。

代码实现

参考llama,全连接层使用带有swiglu激活函数的ffn(position-wise feed-forward network)的公式如下1
ffn swiglu ( x , w , v , w 2 ) = ( swish 1 ( x w ) ⊗ x v ) w 2 (4) \text{ffn}_{\text{swiglu}}(\pmb x,w,v,w_2) = (\text{swish}_1(\pmb xw) \otimes \pmb xv)w_2 \tag 4 ffnswiglu(x,w,v,w2)=(swish1(xw)xv)w2(4)
这里的swish函数可以被silu函数替代:
silu ( x ) = x σ ( x ) \text{silu}(\pmb x) = \pmb x \sigma(\pmb x) silu(x)=xσ(x)
即:
ffn swiglu ( x , w , v , w 2 ) = ( silu ( x w ) ⊗ x v ) w 2 (5) \text{ffn}_{\text{swiglu}}(\pmb x,w,v,w_2) = (\text{silu}(\pmb xw) \otimes \pmb xv)w_2 \tag 5 ffnswiglu(x,w,v,w2)=(silu(xw)xv)w2(5)

import torch
from torch import nn
import torch.nn.functional as f

class feedforward(nn.module):
    def __init__(self, hidden_size: int, intermediate_size: int) -> none:
       	super().__init__()

        self.w1 = nn.linear(hidden_size, intermediate_size, bias=false)
        self.w2 = nn.linear(intermediate_size, hidden_size, bias=false)
        self.w3 = nn.linear(hidden_size, intermediate_size, bias=false)
        
    def forward(self, x: torch.tensor) -> torch.tensor:
        # x: (batch_size, seq_len, hidden_size)
        # w1(x) -> (batch_size, seq_len, intermediate_size)
        # w1(x) -> (batch_size, seq_len, intermediate_size)
        # w2(*) -> (batch_size, seq_len, hidden_size)
    	return self.w2(f.silu(self.w1(x)) * self.w3(x))
            

这里w1,w2,w3分别对应公式(5)中的 w , w 2 , v w,w_2,v w,w2,v

注意维度,其中w1,w3x转换到维度intermediate_size,然后w2转换回hidden_size

参考


  1. [论文翻译]glu variants improve transformer ↩︎ ↩︎ ↩︎

  2. [论文笔记]language modeling with gated convolutional networks ↩︎

  3. [论文笔记]searching for activation functions ↩︎ ↩︎ ↩︎

  4. [论文笔记]gaussian error linear units (gelus) ↩︎

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com