
The LLaMA Series Explained (Principles and Code Walkthrough): LLaMA


LLaMA Explained

LLaMA (Large Language Model Meta AI) is a family of large language models developed by Meta (formerly Facebook), designed to improve performance on natural language processing (NLP) tasks. LLaMA is based on the Transformer architecture and is trained on large-scale data so that it performs well across a wide range of language tasks.

Meta AI argues that, for a given compute budget, the best performance is achieved not by the largest models, but by smaller models trained on more data.

Model Architecture

Like other generative models such as GPT, LLaMA uses only the decoder part of the Transformer, but makes three improvements over the standard architecture:

  1. Pre-normalization, as in GPT-3. To improve training stability, the input of each Transformer sub-layer is normalized instead of the output, using the RMSNorm normalization function.
  2. The ReLU non-linearity is replaced with the SwiGLU activation function to improve performance, using a hidden dimension of $\frac{2}{3} \cdot 4d$ instead of the $4d$ used in PaLM.
  3. As in GPT-Neo, absolute positional embeddings are removed and rotary positional embeddings (RoPE) are added instead.

Each of these three improvements is described below.

RMSNorm

RMSNorm (Root Mean Square Normalization) is a normalization technique used to stabilize and accelerate neural network training. Unlike other normalization methods such as BatchNorm and LayerNorm, RMSNorm normalizes the input by its root mean square (RMS). The formula is:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma$$

where $x$ is the input vector, $d$ is its dimension, $\epsilon$ is a small constant that avoids division by zero, and $\gamma$ is a learnable scaling parameter.

In LLaMA, RMSNorm is implemented as follows:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable scale (gamma)

    def _norm(self, x):
        # normalize by the root mean square: x / sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
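
As a quick sanity check (the tensor shapes here are illustrative, not taken from the LLaMA code), the module can be applied to a random tensor and the per-position RMS of the output inspected:

import torch
from torch import nn

# assumes the RMSNorm class above is in scope
norm = RMSNorm(dim=8)
x = torch.randn(2, 4, 8)              # (batch, seq_len, dim)
y = norm(x)
print(y.shape)                        # torch.Size([2, 4, 8])
print(y.pow(2).mean(-1).sqrt())       # ~1 everywhere at init (the weight is all ones)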

SwiGLU Activation Function

SwiGLU (Swish-Gated Linear Unit) is an activation function that combines the Swish activation with a gating mechanism, which increases the expressiveness and performance of the model. It is defined as:

$$\text{SwiGLU}(x) = \text{Swish}(\text{Linear}_1(x)) \cdot \text{Linear}_2(x)$$
$$\text{Swish}(x) = x \cdot \sigma(x)$$
$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

where $\text{Linear}_1$ and $\text{Linear}_2$ are two separate linear transformations.

In the LLaMA code, the Swish part is computed with F.silu(x); the full feed-forward block (see the FeedForward module in the complete code below) computes w2(F.silu(w1(x)) * w3(x)). A minimal sketch follows.
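
Below is a minimal, single-GPU sketch of the SwiGLU feed-forward block, using plain nn.Linear in place of the fairscale parallel layers from the official code; the hidden size follows the 2/3 · 4d rule mentioned above (the class name and shapes are illustrative):

import torch
from torch import nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Minimal SwiGLU FFN sketch: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int, multiple_of: int = 256):
        super().__init__()
        hidden_dim = int(2 * (4 * dim) / 3)                              # 2/3 * 4d
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)                 # gate branch (goes through silu)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)                 # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)                 # linear branch

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(2, 4, 512)
print(SwiGLUFeedForward(512)(x).shape)   # torch.Size([2, 4, 512])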

RoPE

Rotary Position Embedding (RoPE) is a method of injecting positional information into sequence models such as the Transformer. It encodes position by rotating the input vectors in the complex domain. Compared with traditional positional encodings (such as the sinusoidal encoding), RoPE captures relative positional information more directly and improves the model's expressiveness. The concrete steps are:

  1. Compute the frequencies:
     $$f_i = \frac{1}{\theta^{2i/d}}$$
     where $\theta$ is a constant (usually 10000) and $i$ is the index over the vector dimensions.

  2. Compute the rotation angles:
     $$\text{angle}(t) = t \cdot f_i$$
     where $t$ is the position index.

  3. Apply the rotation:
     For the input vector $x_t$ at each position $t$, apply a rotation in the complex domain:
     $$x_t' = x_t \cdot e^{j \cdot \text{angle}(t)}$$

With conventional positional encodings, by contrast, a position-encoding vector (also $d$-dimensional) is computed before the query, key and value vectors, added to the word embedding, and the result is then multiplied by the corresponding projection matrices.

With RoPE, the self-attention flow is: for each word embedding in the token sequence, first compute its query and key vectors, then compute the rotary positional encoding for each token position, apply the rotation to the query and key vectors in pairs of two elements, and finally take the inner product of queries and keys to obtain the self-attention scores.

[Figure: an illustration of the rotation transformation applied to each pair of query/key dimensions; image not reproduced here.]

RoPE effectively preserves relative positional relationships: encodings of nearby positions are similar, while encodings of distant positions differ, which strengthens the model's ability to perceive and use positional information. Absolute positional encodings (such as the sinusoidal encoding or learned positional embeddings) lack this property, since they represent only absolute positions rather than relative ones.

A detailed derivation of these formulas can be found in the article linked here.

In LLaMA, RoPE is implemented as follows:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # frequencies f_i = 1 / theta^(2i/d), one per pair of dimensions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0 .. end-1
    freqs = torch.outer(t, freqs).float()  # angles t * f_i, shape (end, dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: e^{i * angle}
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # keep the seq_len and head_dim axes, set all other axes to 1 for broadcasting
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # view consecutive pairs of dimensions as complex numbers and rotate them
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
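
As a small usage sketch (tensor shapes chosen for illustration only, assuming the functions above are in scope), the helpers can be exercised directly. Shifting both query and key positions by the same offset should leave the attention scores unchanged, which is exactly the relative-position property described above:

import torch

head_dim, seq_len = 64, 16
freqs_cis = precompute_freqs_cis(head_dim, 2 * seq_len)

xq = torch.randn(1, seq_len, 8, head_dim)   # (batch, seq_len, n_heads, head_dim)
xk = torch.randn(1, seq_len, 8, head_dim)

# rotate with positions [0, seq_len) and with the same positions shifted by 3
q0, k0 = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis[:seq_len])
q1, k1 = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis[3 : 3 + seq_len])

scores0 = torch.einsum("bqhd,bkhd->bhqk", q0, k0)
scores1 = torch.einsum("bqhd,bkhd->bhqk", q1, k1)
print(torch.allclose(scores0, scores1, atol=1e-3))   # True: scores depend only on relative positions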

The following code shows the attention module with rotary position embeddings applied:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # KV cache for incremental decoding: one slot per (batch, position, head)
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # apply the rotary position embedding to queries and keys
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # write the new keys/values into the cache, then attend over the whole prefix
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)
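
The start_pos argument and the cache_k / cache_v buffers implement a KV cache for incremental decoding: at each generation step only the new tokens are projected, their keys and values are written into the cache at positions [start_pos, start_pos + seqlen), and attention is computed over the entire cached prefix. Below is a minimal, self-contained sketch of the same pattern (single head, plain nn.Linear instead of the fairscale parallel layers; the class name and shapes are illustrative, and the causal mask is omitted for brevity):

import math
import torch
from torch import nn
import torch.nn.functional as F

class TinyCachedAttention(nn.Module):
    """Single-head attention with a KV cache, illustrative only."""
    def __init__(self, dim: int, max_seq_len: int = 64, max_batch: int = 2):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.register_buffer("cache_k", torch.zeros(max_batch, max_seq_len, dim))
        self.register_buffer("cache_v", torch.zeros(max_batch, max_seq_len, dim))

    def forward(self, x, start_pos: int):
        bsz, seqlen, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # append the new keys/values at positions [start_pos, start_pos + seqlen)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = k
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = v
        keys = self.cache_k[:bsz, : start_pos + seqlen]      # full cached prefix
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # (causal mask omitted for brevity; the real code adds one when seqlen > 1)
        scores = q @ keys.transpose(1, 2) / math.sqrt(self.dim)
        return F.softmax(scores, dim=-1) @ values

attn = TinyCachedAttention(dim=16)
with torch.no_grad():
    prompt = torch.randn(1, 5, 16)
    out = attn(prompt, start_pos=0)      # prefill: fills cache positions 0..4
    nxt = torch.randn(1, 1, 16)
    out = attn(nxt, start_pos=5)         # decode step: attends over 6 cached positions
print(out.shape)                          # torch.Size([1, 1, 16])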

The complete LLaMA implementation is given below:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
  
  
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
  
  
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)
  
  
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        # SwiGLU: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
  
  
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # causal mask over the newly added positions
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute logits for the last position
        return output.float()
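
Because the model uses fairscale's model-parallel layers and allocates its KV caches on the GPU, instantiating it requires an initialized torch.distributed process group and a CUDA device. Below is a rough single-process usage sketch (world size 1); the ModelArgs values, the gloo backend, and the environment setup are illustrative assumptions, and the official code instead launches with torchrun, uses NCCL, and loads pretrained weights, so the logits here are meaningless and only show the calling convention:

import os
import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

# Single-process setup (illustrative): one rank, model-parallel size 1.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
initialize_model_parallel(1)

args = ModelArgs(dim=512, n_layers=8, n_heads=8, vocab_size=32000,
                 max_batch_size=1, max_seq_len=128)
model = Transformer(args).cuda()   # weights are uninitialized here; the real code loads a checkpoint

# Prefill the prompt, then decode one token at a time, advancing start_pos.
tokens = torch.randint(0, args.vocab_size, (1, 10)).cuda()
logits = model(tokens, start_pos=0)                     # logits for the last prompt token
next_token = torch.argmax(logits, dim=-1, keepdim=True)
logits = model(next_token, start_pos=10)                # decode step reuses the KV cache
print(logits.shape)                                     # torch.Size([1, 32000])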