
Artificial Intelligence | Deep Learning: the Multimodal Conditioning Mechanism Cross Attention, Principles and Implementation

July 31, 2024
I have written about Attention before, but looking back at those posts they feel long-winded. Since the next post, on Stable Diffusion, involves cross-attention, I'll pull Attention out and cover it briefly on its own. This post therefore serves two purposes: first, as a supplement to the Stable Diffusion post; second, as groundwork for the upcoming Vision Transformer and Swin Transformer posts.

1. Introduction

2. The Attention Idea

Let's use machine translation as an example to build more intuition. Suppose the source text is "汤姆追逐杰瑞" ("Tom chases Jerry") and, for convenience, the target vocabulary contains only the words tom, chase, and jerry. When we translate "汤姆", the attention mechanism described above scores the current decoding step's query against the encoder representation (key) of every source word and uses the softmax-normalized scores to take a weighted sum of the values, so the source word "汤姆" contributes the most when the model emits "tom". The sketch below walks through this computation.
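As a minimal numeric sketch (the vectors below are random placeholders rather than values from the article), scaled dot-product attention computes softmax(Q·Kᵀ/√d)·V:

import torch
import torch.nn.functional as F

d = 4                                  # toy embedding size
keys = values = torch.randn(3, d)      # hypothetical encoder states for 汤姆 / 追逐 / 杰瑞
query = torch.randn(1, d)              # hypothetical decoder query for the step that should emit "tom"

scores = query @ keys.T / d ** 0.5     # [1, 3] similarity of the query to each source word
weights = F.softmax(scores, dim=-1)    # [1, 3] attention distribution over the source words
context = weights @ values             # [1, d] weighted sum passed on to the decoder
print(weights, context.shape)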

3. Self-Attention

4. Multi-Head Attention


5. Padding Mask

# build the padding mask
pad_mask = input_ids.eq(0)  # boolean mask: True at padded positions, False elsewhere  [batch_size, seq_len]
# add a dimension so the mask matches the shape of the q @ k^T attention weights  [batch_size, seq_len, seq_len]
pad_mask = pad_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)

# (batch_size, num_heads, seq_len, seq_len)
att_weights = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)  # scaled dot product

# with multiple heads the mask must be expanded to 4 dimensions:
# [batch_size, seq_len, seq_len] -> [batch_size, num_heads, seq_len, seq_len]
pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
att_weights.masked_fill_(pad_mask, float('-inf'))  # set padded positions to -inf
att_weights = torch.softmax(att_weights, dim=-1)   # softmax over the last dimension

context = torch.matmul(att_weights, v)  # (batch_size, num_heads, seq_len, depth)
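The reason the -inf (or -1e9) fill works: after the softmax, those positions receive a weight of exactly (or effectively) zero, so padded tokens contribute nothing to the weighted sum. A quick standalone check with made-up scores:

import torch

scores = torch.tensor([[2.0, 1.0, float('-inf')]])  # pretend the third key is a padding token
print(torch.softmax(scores, dim=-1))                 # tensor([[0.7311, 0.2689, 0.0000]])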

6. Cross Attention
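Cross attention changes only where Q, K, and V come from: the query is computed from one source (here the image feature map x), while the keys and values are computed from another (the conditioning sequence context, e.g. text embeddings); the rest is the same scaled dot-product attention, softmax(Q·Kᵀ/√d)·V. Because the output keeps the shape of x, the module can be inserted wherever the image features flow, which is how Stable Diffusion conditions its U-Net on text. The two implementations in sections 7.3 and 7.4 below follow exactly this pattern.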

7. Code Implementation

7.1 self-attention

class SelfAttention(nn.Module):
    def __init__(self, emb_dim):
        super(SelfAttention, self).__init__()
        self.emb_dim = emb_dim

        self.wq = nn.Linear(emb_dim, emb_dim, bias=False)
        self.wk = nn.Linear(emb_dim, emb_dim, bias=False)
        self.wv = nn.Linear(emb_dim, emb_dim, bias=False)

        self.fc = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, pad_mask=None):
        # x: [batch_size, seq_len, emb_dim] = [3, 5, 512]

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        att_weights = torch.bmm(q, k.transpose(1, 2))   # [batch_size, seq_len, seq_len] = [3, 5, 5]
        att_weights = att_weights / math.sqrt(self.emb_dim)

        if pad_mask is not None:
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        output = torch.bmm(att_weights, v)   # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        output = self.fc(output)

        return output, att_weights

7.2 multi-head attention

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_dim, num_heads, att_dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.att_dropout = att_dropout

        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.depth = emb_dim // num_heads

        self.wq = nn.Linear(emb_dim, emb_dim, bias=False)
        self.wk = nn.Linear(emb_dim, emb_dim, bias=False)
        self.wv = nn.Linear(emb_dim, emb_dim, bias=False)

        self.fc = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, pad_mask=None):
        # x: [batch_size, seq_len, emb_dim] = [3, 5, 512]
        batch_size = x.size(0)

        # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # split into heads  [batch_size, num_heads, seq_len, depth] = [3, 8, 5, 512/8=64]
        q = q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # [batch_size, num_heads, seq_len, seq_len] = [3, 8, 5, 5]
        att_weights = torch.matmul(q, k.transpose(-2, -1))
        att_weights = att_weights / math.sqrt(self.depth)

        if pad_mask is not None:
            # with multiple heads the mask must be expanded to 4 dimensions:
            # [batch_size, seq_len, seq_len] -> [batch_size, num_heads, seq_len, seq_len]
            pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)

        # my hand-rolled multi-head attention didn't perform as well as torch's;
        # my guess is that torch applies dropout to the attention weights rather than to the fc output
        if self.att_dropout > 0.0:
            att_weights = F.dropout(att_weights, p=self.att_dropout)

        # [batch_size, num_heads, seq_len, depth] = [3, 8, 5, 64]
        output = torch.matmul(att_weights, v)

        # concatenate the heads  [batch_size, seq_len, emb_dim] = [3, 5, 512]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)

        output = self.fc(output)

        return output, att_weights

7.3 multi-head cross attention

class CrossMultiAttention(nn.Module):
    def __init__(self, in_channels, emb_dim, num_heads, att_dropout=0.0, dropout=0.0):
        super(CrossMultiAttention, self).__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads

        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.depth = emb_dim // num_heads
        self.scale = self.depth ** -0.5   # scale scores by the per-head dimension, matching MultiHeadAttention above

        self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)

        self.wq = nn.Linear(emb_dim, emb_dim)
        self.wk = nn.Linear(emb_dim, emb_dim)
        self.wv = nn.Linear(emb_dim, emb_dim)

        self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context, pad_mask=None):
        '''
        :param x: image feature map [batch_size, c, h, w]
        :param context: conditioning sequence, e.g. text embeddings [batch_size, seq_len, emb_dim]
        :param pad_mask: [batch_size, h*w, seq_len]
        :return:
        '''
        batch_size, c, h, w = x.shape

        x = self.proj_in(x)   # [batch_size, emb_dim, h, w] = [3, 512, 512, 512]
        x = rearrange(x, 'b c h w -> b (h w) c')   # [batch_size, h*w, emb_dim] = [3, 262144, 512]

        q = self.wq(x)        # queries from the image     [batch_size, h*w, emb_dim] = [3, 262144, 512]
        k = self.wk(context)  # keys from the context      [batch_size, seq_len, emb_dim] = [3, 5, 512]
        v = self.wv(context)  # values from the context

        q = q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, h*w, depth]
        k = k.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, seq_len, depth]
        v = v.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # [batch_size, num_heads, h*w, seq_len]
        att_weights = torch.einsum('bnid,bnjd -> bnij', q, k)
        att_weights = att_weights * self.scale

        if pad_mask is not None:
            # with multiple heads the mask must be expanded to 4 dimensions:
            # [batch_size, h*w, seq_len] -> [batch_size, num_heads, h*w, seq_len]
            pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        out = torch.einsum('bnij, bnjd -> bnid', att_weights, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)   # [batch_size, h*w, emb_dim]

        print(out.shape)

        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)   # [batch_size, emb_dim, h, w]
        out = self.proj_out(out)   # [batch_size, in_channels, h, w]

        return out, att_weights

7.4 cross attention

class CrossAttention(nn.Module):
    def __init__(self, in_channels, emb_dim, att_dropout=0.0, dropout=0.0):
        super(CrossAttention, self).__init__()
        self.emb_dim = emb_dim
        self.scale = emb_dim ** -0.5

        self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)

        self.wq = nn.Linear(emb_dim, emb_dim)
        self.wk = nn.Linear(emb_dim, emb_dim)
        self.wv = nn.Linear(emb_dim, emb_dim)

        self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context, pad_mask=None):
        '''
        :param x: image feature map [batch_size, c, h, w]
        :param context: conditioning sequence, e.g. text embeddings [batch_size, seq_len, emb_dim]
        :param pad_mask: [batch_size, h*w, seq_len]
        :return:
        '''
        b, c, h, w = x.shape

        x = self.proj_in(x)   # [batch_size, emb_dim, h, w] = [3, 512, 512, 512]
        x = rearrange(x, 'b c h w -> b (h w) c')   # [batch_size, h*w, emb_dim] = [3, 262144, 512]

        q = self.wq(x)        # queries from the image     [batch_size, h*w, emb_dim] = [3, 262144, 512]
        k = self.wk(context)  # keys from the context      [batch_size, seq_len, emb_dim] = [3, 5, 512]
        v = self.wv(context)  # values from the context

        # [batch_size, h*w, seq_len]
        att_weights = torch.einsum('bid,bjd -> bij', q, k)
        att_weights = att_weights * self.scale

        if pad_mask is not None:
            # [batch_size, h*w, seq_len]
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        out = torch.einsum('bij, bjd -> bid', att_weights, v)   # [batch_size, h*w, emb_dim]

        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)   # [batch_size, emb_dim, h, w]
        out = self.proj_out(out)   # [batch_size, in_channels, h, w]

        print(out.shape)

        return out, att_weights

7.5 main

# coding:utf-8
# @email: wangguisen@donews.com
# @time: 2023/3/22 22:58
# @file: att_test.py
'''
self attention
multi-head attention
cross attention
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from torch.nn import MultiheadAttention

if __name__ == '__main__':
    '''
    Assume the input has already been mapped through the vocabulary:
    batch_size = 3
    seq_len = max_len = 5
    pad = 0
    emb_dim = 512
    '''
    batch_size = 3
    seq_len = 5
    emb_dim = 512
    # vocabulary size for this example
    vocab_size = 301

    input_ids = torch.tensor([[100, 200, 300, 300, 0],
                              [22, 33, 44, 0, 0],
                              [66, 55, 66, 30, 0]], dtype=torch.long)

    pad_mask = input_ids.eq(0)  # boolean mask: True at padded positions, False elsewhere  [batch_size, seq_len]
    # uncomment the next line (and the calls below) when testing SelfAttention / MultiHeadAttention
    # pad_mask = pad_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len] = [3, 5, 5]

    inputs = nn.Embedding(vocab_size, embedding_dim=emb_dim)(input_ids)   # [batch_size, seq_len, emb_dim] = [3, 5, 512]

    # self_att = SelfAttention(emb_dim=emb_dim)
    # self_att(inputs, pad_mask=pad_mask)

    # multi_att = MultiHeadAttention(emb_dim=emb_dim, num_heads=8)
    # multi_att(inputs, pad_mask=pad_mask)

    # image data  [batch_size, c, h, w]
    input_img = torch.randn((3, 3, 512, 512))
    pad_mask = pad_mask.unsqueeze(1).expand(batch_size, 512*512, seq_len)  # [batch_size, h*w, seq_len]
    # cross_att = CrossMultiAttention(in_channels=3, emb_dim=emb_dim, num_heads=8, att_dropout=0.0, dropout=0.0)
    # cross_att(x=input_img, context=inputs, pad_mask=pad_mask)
    cross_att = CrossAttention(in_channels=3, emb_dim=emb_dim, att_dropout=0.0, dropout=0.0)
    cross_att(x=input_img, context=inputs, pad_mask=pad_mask)
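With the sizes above, the script exercises the single-head CrossAttention path, and the print inside its forward should show torch.Size([3, 3, 512, 512]): the output has the same shape as the input image. Note that a 512×512 feature map means 262,144 query tokens, so this toy example is slow and fairly memory-hungry; shrink the image if you only want to check the shapes.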



