I.Introduction

In earlier posts we covered the Causal Mask and RoPE (rotary position embedding) implementations based on LLM source code. This post walks through the Transformer implementation in that source code, analyzing the dimension changes and the meaning of the code line by line, so that the flow of a Transformer inside an LLM becomes clear.

II.Transformer Layer Dimensions

We have covered this basic Transformer structure many times before; here we go over it once more with the dimension changes spelled out. For a more detailed introduction, see: LLM - Transformer && LLaMA2 结构分析与 LoRA 详解.

1.Single Sample

- Embedding Layer

For a typical LLM, the embedding dimension d_model refers to the size of the continuous vector that each input token is mapped to by the embedding layer. For example, d_model is 768 in BERT-base, while in some of today's large models it reaches 8192.
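
As a minimal sketch of this mapping (the vocab_size and d_model values below are purely illustrative):

import torch
import torch.nn as nn

vocab_size, d_model = 30000, 768                   # illustrative sizes, not tied to a specific model

embedding = nn.Embedding(vocab_size, d_model)
token_ids = torch.tensor([101, 2023, 2003, 102])   # a short sequence of 4 token ids
vectors = embedding(token_ids)
print(vectors.shape)                               # torch.Size([4, 768])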

- Transformer Layer

The output dimension of a Transformer layer usually matches the embedding dimension, i.e., d_model. Sticking with the BERT-base example, the output of every Transformer layer (called an encoder layer in BERT; in LLMs these are mostly decoder layers) is likewise a vector of dimension 768 (or 8192 for the larger models above).
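
A quick way to see this shape-preserving behavior is with PyTorch's generic nn.TransformerEncoderLayer as a stand-in (a sketch only; BERT and LLM decoder layers have their own implementations):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
x = torch.randn(1, 5, 768)        # (batch, seq_len, d_model)
y = layer(x)
print(y.shape)                    # torch.Size([1, 5, 768]) -- identical to the input shape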

- lm_head Layer 

The final lm_head (language-model head) has an output dimension equal to the vocabulary size vocab_size, because its job is to turn the Transformer output into a score over every vocabulary entry, from which a probability distribution is obtained. For example, if the model's vocabulary contains 30000 tokens, the lm_head output dimension is 30000.
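
A minimal sketch of the lm_head projection (sizes are illustrative):

import torch
import torch.nn as nn

d_model, vocab_size = 768, 30000
lm_head = nn.Linear(d_model, vocab_size)

hidden = torch.randn(1, 5, d_model)        # output of the last Transformer layer
logits = lm_head(hidden)                   # (1, 5, 30000): one score per vocabulary entry
probs = torch.softmax(logits, dim=-1)      # normalized distribution over the vocabulary
print(logits.shape, probs.shape)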

- hidden_states

hidden_states is a term that comes up when inspecting a Transformer's intermediate outputs or analyzing its internals. It records the hidden-layer activations: for each input token, every Transformer layer produces an output vector that represents the input at that particular depth of the network. For an N-layer stacked Transformer and a given input sequence, the model therefore has N such sets of hidden states. Alongside the hidden states, each layer also produces attention distributions; these come from the self-attention mechanism, a key component of the Transformer that lets the model weigh how different parts of the input depend on one another.

Tips:

Suppose we have a BERT-base model with 12 Transformer layers and an output dimension of 768 per layer. Given an input sequence of 5 tokens, each token is first turned into a 768-dimensional embedding, so hidden_states starts out as a tensor of shape (5, 768). After the 12 Transformer layers, collecting the output of every layer gives a 3-D tensor of shape (12, 5, 768), containing the representation of each token in the sequence at every layer.
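
With the transformers library these per-layer hidden states can be inspected directly. Note that Hugging Face also returns the embedding output, so the tuple has 13 entries for bert-base rather than 12; a sketch:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

inputs = tokenizer("a tiny example", return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)

print(len(outputs.hidden_states))                      # 13: embedding output + 12 layers
print(outputs.hidden_states[-1].shape)                 # (1, seq_len, 768)
layer_stack = torch.stack(outputs.hidden_states[1:])   # (12, 1, seq_len, 768)
print(layer_stack.shape)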

2.Batched Samples

The above describes the flow for a single sample; now let's look at how the dimensions change with a batch dimension (batch_size). Assume a BERT-base model with:

- Vocabulary size vocab_size = 30000
- Embedding dimension d_model = 768
- Number of stacked layers N = 12
- Maximum sequence length max_seq_length = 128
- Batch size batch_size = 32

Here is how the dimensions change concretely as data flows through the model:

- Input Layer

The input layer has shape (batch_size, max_seq_length), i.e., (32, 128). Each row of length 128 holds the token ids of one sequence in the batch, i.e., the result of passing the text through the tokenizer.

- Embedding Layer

The (bsz, max_seq_length) integer tensor is fed into the Embedding layer. Taking BERT as the example, it is mapped to shape (bsz, max_seq_length, d_model), i.e., (32, 128, 768): we now have 32 samples, each a sequence of 128 word-embedding vectors of dimension 768.

- Transformer Layer

Each Transformer layer takes a tensor of shape (bsz, max_seq_length, d_model) and, after multi-head attention (and the feed-forward sublayer), outputs a tensor of the same shape, since Transformer layers normally keep input and output dimensions identical. After this layer the shape is therefore still (bsz, max_seq_length, d_model), i.e., (32, 128, 768).

- lm_head Layer

The lm_head linear layer converts the Transformer output (bsz, max_seq_length, d_model) into a tensor of shape (bsz, max_seq_length, vocab_size), i.e., (32, 128, 30000). This is generally implemented as a single Linear layer; more complex LLMs may insert additional MLP layers, but the purpose of lm_head is always to map d_model to vocab_size, producing per-position scores over the vocabulary that represent the distribution over possible tokens.

Tips:

If we also keep the intermediate hidden_states, then for every token in the sequence each Transformer layer yields a 768-dimensional vector, so for the whole batch the hidden_states of a single layer have shape (batch_size, max_seq_length, d_model), i.e., (32, 128, 768). Saving the hidden_states of all layers gives a 4-D tensor of shape (num_layers, batch_size, max_seq_length, d_model), i.e., (12, 32, 128, 768), where num_layers is the N mentioned earlier, the number of stacked Transformer layers in the LLM. This shows how the various dimensions evolve as the data flows through the model. Note that in practice sequence lengths differ, so padding and masking are needed for batching to work, but that does not change the basic dimension flow described above.
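
The whole batched shape flow above can be reproduced with random tensors and generic PyTorch modules (a sketch under the stated sizes; the encoder layer here is a stand-in for BERT's real layers, and its weights are reused across the 12 passes purely for brevity):

import torch
import torch.nn as nn

vocab_size, d_model, N = 30000, 768, 12
batch_size, max_seq_length = 32, 128

token_ids = torch.randint(0, vocab_size, (batch_size, max_seq_length))    # Input Layer: (32, 128)
embed = nn.Embedding(vocab_size, d_model)
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=12, batch_first=True)
lm_head = nn.Linear(d_model, vocab_size)

x = embed(token_ids)                        # Embedding Layer: (32, 128, 768)
hidden_states = []
for _ in range(N):                          # N stacked Transformer layers (same weights reused here)
    x = layer(x)
    hidden_states.append(x)                 # each entry: (32, 128, 768)

logits = lm_head(x)                         # lm_head: (32, 128, 30000)
all_hidden = torch.stack(hidden_states)     # (12, 32, 128, 768)
print(x.shape, logits.shape, all_hidden.shape)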

III.Transformer Dimension Transformations

So that everyone can debug and test quickly on a local machine, the examples below use BERT and its tokenizer as the base model to build the token ids and embeddings. The Multi-Head Attention that follows is adapted from Qwen's implementation logic, keeping the overall style intact; for the complete code, refer to the modeling.py on Hugging Face.

1.Input Layer 

We obtain the input layer and the embedding layer through the BERT model and its tokenizer:

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pretrained_bert = BertModel.from_pretrained('bert-base-uncased')

    input_texts = ["This is a test sentence.", "Here is another test sentence."]
    input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,
                                  return_tensors='pt') for text in input_texts]
    input_ids = torch.cat(input_ids, dim=0)  # Concatenate the per-sample (1, 10) tensors into a (2, 10) batch

For convenience, input_texts contains two samples, so bsz = 2, max_length = 10, and d_model = 768. Each tokenizer.encode call with return_tensors='pt' yields a tensor of shape (1, 10), and concatenating them gives the initial shape (bsz, max_length) = (2, 10):

tensor([[ 101, 2023, 2003, 1037, 3231, 6251, 1012,  102,    0,    0],
        [ 101, 2182, 2003, 2178, 3231, 6251, 1012,  102,    0,    0]])
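
Note that tokenizer.encode only returns token ids. In practice it is common to call the tokenizer on the whole batch instead, which also returns an attention_mask marking the padded positions; a sketch of that alternative (not the code used in this walkthrough):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
input_texts = ["This is a test sentence.", "Here is another test sentence."]

encoded = tokenizer(input_texts, max_length=10, padding='max_length',
                    truncation=True, return_tensors='pt')
print(encoded['input_ids'].shape)      # torch.Size([2, 10])
print(encoded['attention_mask'])       # 1 for real tokens, 0 for padding positions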

2.Embedding Layer

    with torch.no_grad():
        embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT model
        print(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

Here we run input_ids through the pretrained BERT model and take its output (the last hidden state) as the embedding of each token id; strictly speaking this is the full BERT output rather than just its Embedding layer, but it serves the same purpose for this walkthrough. Since d_model = 768, the (bsz, max_length) token-id tensor becomes (bsz, max_length, d_model), i.e., (2, 10, 768):

tensor([[[-3.7545e-02,  5.3234e-04, -1.3553e-02,  ..., -1.9545e-01,
           2.3569e-01,  4.7479e-01],
         [-7.1746e-01, -2.8763e-01,  1.4100e-01,  ..., -5.5593e-01,
           6.1830e-01,  3.9255e-01],
         [-1.9318e-01, -4.0202e-01,  3.2924e-01,  ..., -1.5206e-01,
           3.4014e-01,  1.0233e+00],
         ...,
         [ 1.5273e-01,  1.1651e-01,  1.5754e-01,  ...,  6.9833e-02,
          -8.5732e-01, -4.3875e-02],
         [ 7.0679e-02, -2.3521e-01,  6.1713e-01,  ..., -7.3852e-02,
           2.5070e-01, -6.3240e-02],
         [-1.3249e-01, -3.6026e-01,  3.5025e-01,  ..., -5.5981e-02,
           1.0420e-01, -4.3954e-01]],

        [[-2.9592e-02, -1.4164e-01, -2.2295e-03,  ..., -1.3087e-01,
           2.9421e-01,  5.5132e-01],
         [-1.0146e+00, -6.8757e-01,  1.9959e-01,  ..., -4.2000e-01,
           1.7332e-01,  9.2754e-02],
         [-1.3425e-01, -8.1044e-01,  2.6674e-01,  ...,  4.6978e-02,
          -1.0026e-01,  4.5293e-01],
         ...,
         [ 4.5527e-01,  2.2234e-02, -3.6816e-01,  ...,  4.3154e-01,
          -8.6396e-01, -2.8542e-01],
         [ 1.4188e-01, -2.4001e-01,  6.5681e-01,  ..., -5.7224e-02,
           3.1025e-01, -9.0286e-02],
         [-3.9205e-02, -3.2815e-01,  4.7910e-01,  ..., -4.7641e-02,
           2.9916e-02, -4.5328e-01]]])

3.Multi-Head Attention

        embed_dim = embedded_output.size(-1)
        num_heads = 4
        model = BITDDDAttention(embed_dim, num_heads)

        output = model(embedded_output)
        print(output.size())  # Output shape should be (2, 10, embed_dim)

For this layer, we ported the core of the Attention implementation from an LLM's modeling.py into the BITDDDAttention class:

class BITDDDAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(BITDDDAttention, self).__init__()

        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads

        # Build the Q/K/V projections and the final fully-connected (MLP) output layer
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Apply the attention weights to the value
        attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

        # Apply a linear layer to the output
        x = self.fc_out(attention_output)
        return x

Let's now walk through the Multi-Head Attention forward pass line by line and track the dimension changes:

- Size

batch_size, seq_len, _ = x.size()

This line unpacks the bsz and seq_len of the batch fed into the attention layer; since embed_dim was already provided in __init__, the last dimension is ignored with '_'.

- Q/K/V Projections

# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

self.query, self.key, and self.value are all nn.Linear(embed_dim, embed_dim) projection layers, and Q/K/V are handled identically: view reshapes the projected tensor from (bsz, seq_len, embed_dim) to (bsz, seq_len, num_heads, head_dim), and permute then swaps dimensions to produce a (bsz, num_heads, seq_len, head_dim) tensor for the subsequent multi-head computation. The assert in __init__ guarantees that embed_dim splits evenly across the heads:

assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads

With the num_heads and embed_dim given in __init__ above, the resulting shape is (2, 4, 10, 192):

tensor([[[[ 0.0185, -0.1872,  0.1827,  ..., -0.7914, -0.0074, -0.6228],
          [-0.1390,  0.4675,  0.0325,  ...,  0.0187,  0.0912, -0.2692],
          [-0.1342,  0.2904, -0.2637,  ...,  0.1130, -0.0226, -0.3510],
          ...,
          [-0.1151, -0.2627, -0.6453,  ...,  0.4885, -0.1982, -0.1538],
          [-0.1281,  0.2321, -0.0815,  ..., -0.1740,  0.4909, -0.1373],
          [-0.1364,  0.2844, -0.0728,  ..., -0.0620,  0.3605, -0.2292]],

         [[ 0.3641,  0.1707, -0.0567,  ...,  0.0267,  0.3272,  0.1560],
          [-0.1206,  0.6853,  0.0990,  ..., -0.0875,  0.2414,  0.5490],
          [-0.4080,  0.0679,  0.3174,  ...,  0.0970, -0.0127,  0.1664],
          ...,
          [-0.2878,  0.2856,  0.0777,  ..., -0.0791,  0.0847,  0.0545],
          [ 0.2381, -0.1032,  0.2887,  ...,  0.2219,  0.2837,  0.0345],
          [ 0.1421, -0.0956,  0.1983,  ...,  0.1784,  0.1827,  0.0776]],

         [[-0.2031, -0.2496, -0.0072,  ..., -0.1553, -0.0441,  0.0200],
          [-0.2028, -0.4097,  0.1779,  ...,  0.0333, -0.4005, -0.3453],
          [ 0.0926, -0.1818,  0.0492,  ...,  0.3059, -0.6175, -0.2858],
          ...,
          [ 0.3494, -0.4813,  0.7086,  ...,  0.6181,  0.1515, -0.1279],
          [-0.0542,  0.3148,  0.0172,  ...,  0.0037, -0.2878, -0.1582],
          [-0.1381,  0.2450,  0.0490,  ..., -0.0824, -0.2504, -0.2464]],

         [[ 0.6905, -0.1202,  0.6489,  ...,  0.6069,  0.2634, -0.0595],
          [ 0.3937, -0.2795,  0.7692,  ...,  0.1321, -0.0240, -0.1484],
          [ 0.2260, -0.4332,  0.4651,  ..., -0.1797, -0.1127, -0.3294],
          ...,
          [ 0.0168, -0.2892,  0.4032,  ..., -0.4515,  0.3833, -0.7699],
          [ 0.1970, -0.3264,  0.4196,  ...,  0.3044, -0.0819, -0.2083],
          [ 0.2492, -0.3419,  0.5813,  ...,  0.1855, -0.2431, -0.1149]]],


        [[[ 0.0225, -0.2359,  0.0754,  ..., -0.7577,  0.0936, -0.6233],
          [ 0.0479,  0.5459, -0.3047,  ..., -0.3134,  0.0416,  0.0397],
          [ 0.1172,  0.2506, -0.5461,  ...,  0.1287, -0.0441, -0.2074],
          ...,
          [-0.3586, -0.3827, -0.6436,  ...,  0.3915, -0.2485, -0.1576],
          [-0.1502,  0.1852, -0.1007,  ..., -0.1310,  0.5079, -0.1868],
          [-0.1622,  0.2055, -0.1428,  ..., -0.0887,  0.3516, -0.2383]],

         [[ 0.4338,  0.2326, -0.0661,  ...,  0.0309,  0.3088,  0.1711],
          [-0.4011,  0.9250,  0.2983,  ..., -0.4108,  0.4223,  0.6880],
          [-0.2721,  0.4383,  0.6376,  ..., -0.0888, -0.0647, -0.0073],
          ...,
          [ 0.1742,  0.2020, -0.1020,  ..., -0.1444,  0.2459,  0.1079],
          [ 0.2608, -0.0978,  0.2557,  ...,  0.2132,  0.2125,  0.0010],
          [ 0.1041, -0.1335,  0.1523,  ...,  0.1797,  0.1323,  0.0036]],

         [[-0.1826, -0.2200, -0.0026,  ..., -0.1664, -0.0773,  0.0607],
          [-0.1257, -0.2642,  0.6933,  ...,  0.4202, -0.1153, -0.3960],
          [-0.1353, -0.4837,  0.3527,  ...,  0.3592, -0.5616, -0.3685],
          ...,
          [ 0.6056, -0.3298,  0.7872,  ...,  0.3984,  0.4775,  0.2213],
          [-0.1211,  0.3394, -0.0247,  ...,  0.0251, -0.3108, -0.1656],
          [-0.1572,  0.3040,  0.0164,  ..., -0.1026, -0.2737, -0.2175]],

         [[ 0.7236, -0.1187,  0.6491,  ...,  0.6230,  0.2401, -0.0061],
          [-0.0402, -0.0318,  0.7717,  ..., -0.0389,  0.1465, -0.3047],
          [ 0.2734, -0.4473,  0.6278,  ..., -0.3827, -0.0412, -0.7133],
          ...,
          [-0.0418,  0.0670,  0.1462,  ..., -0.6109,  0.4838, -0.4277],
          [ 0.2340, -0.3250,  0.4256,  ...,  0.3217, -0.0688, -0.1837],
          [ 0.1940, -0.2878,  0.5281,  ...,  0.2155, -0.1810, -0.0649]]]])
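
The view + permute reshaping can also be checked in isolation with random data (only the shapes matter here):

import torch

bsz, seq_len, embed_dim, num_heads = 2, 10, 768, 4
head_dim = embed_dim // num_heads                      # 192

x = torch.randn(bsz, seq_len, embed_dim)               # (2, 10, 768), e.g. the output of self.query(x)
q = x.view(bsz, seq_len, num_heads, head_dim)          # (2, 10, 4, 192)
q = q.permute(0, 2, 1, 3)                              # (2, 4, 10, 192)
print(q.shape)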

- Attention Score Computation

# Compute the attention scores
attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
attention_probs = torch.softmax(attention_scores, dim=-1)

Computing the attention scores requires Q and K. Here key is permuted from (bsz, num_heads, seq_len, head_dim) to (bsz, num_heads, head_dim, seq_len), so the matmul yields attention_scores of shape (bsz, num_heads, seq_len, seq_len), i.e., (2, 4, 10, 10). Dividing by sqrt(head_dim) is the scaling in scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, which keeps the products from growing too large; the final softmax(dim=-1) then normalizes the 10 values along the last dimension of the attention scores:

tensor([[[[0.1188, 0.1000, 0.0912, 0.0963, 0.0984, 0.0862, 0.1004, 0.0955,
           0.1040, 0.1093],
          [0.1077, 0.0949, 0.0963, 0.1035, 0.0961, 0.0940, 0.0946, 0.0890,
           0.1090, 0.1149],
          [0.0965, 0.1010, 0.0930, 0.0972, 0.1032, 0.1031, 0.0989, 0.0982,
           0.1062, 0.1026],
          [0.0932, 0.1033, 0.0970, 0.0977, 0.0961, 0.1050, 0.0990, 0.1082,
           0.1006, 0.0999],
          [0.0947, 0.1033, 0.0949, 0.0945, 0.0957, 0.1036, 0.0964, 0.0985,
           0.1083, 0.1102],
          [0.0941, 0.1026, 0.0939, 0.0942, 0.0953, 0.1008, 0.1001, 0.1089,
           0.1038, 0.1063],
          [0.1017, 0.1019, 0.1008, 0.0937, 0.1095, 0.1007, 0.0913, 0.0832,
           0.1092, 0.1079],
          [0.0926, 0.1121, 0.1009, 0.0991, 0.0955, 0.1017, 0.0964, 0.1030,
           0.1004, 0.0982],
          [0.1010, 0.1021, 0.0948, 0.0954, 0.0976, 0.1024, 0.0916, 0.1032,
           0.1049, 0.1071],
          [0.1046, 0.1077, 0.0932, 0.0948, 0.1006, 0.1002, 0.0934, 0.0983,
           0.1042, 0.1032]],

         ......         

         [[0.1047, 0.0999, 0.1045, 0.1054, 0.0979, 0.1071, 0.0863, 0.0884,
           0.1012, 0.1046],
          [0.1019, 0.1113, 0.1010, 0.0990, 0.0981, 0.1060, 0.0872, 0.0915,
           0.1017, 0.1022],
          [0.0977, 0.0996, 0.0993, 0.1027, 0.0970, 0.0985, 0.0977, 0.0990,
           0.1069, 0.1018],
          [0.1042, 0.1125, 0.1049, 0.1022, 0.0981, 0.0950, 0.0864, 0.0876,
           0.1045, 0.1046],
          [0.0987, 0.1151, 0.1018, 0.0956, 0.0923, 0.0955, 0.0938, 0.0910,
           0.1073, 0.1087],
          [0.0985, 0.1143, 0.0936, 0.1029, 0.0954, 0.1028, 0.0857, 0.0901,
           0.1076, 0.1092],
          [0.0961, 0.0908, 0.1013, 0.1055, 0.0992, 0.1035, 0.0919, 0.0971,
           0.1090, 0.1056],
          [0.0874, 0.0935, 0.1012, 0.1057, 0.1044, 0.0968, 0.0936, 0.0950,
           0.1132, 0.1091],
          [0.1041, 0.1118, 0.0958, 0.0968, 0.0971, 0.1044, 0.0925, 0.0906,
           0.1028, 0.1041],
          [0.1028, 0.1123, 0.0971, 0.1005, 0.1013, 0.1020, 0.0896, 0.0894,
           0.1026, 0.1023]]]])
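
A quick sanity check of the normalization claim: every row of the softmax output sums to 1 (shown here with stand-in scores of the same shape):

import torch

scores = torch.randn(2, 4, 10, 10)                               # stand-in attention scores
probs = torch.softmax(scores, dim=-1)
print(torch.allclose(probs.sum(dim=-1), torch.ones(2, 4, 10)))   # True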

- Attention Output 

# Apply the attention weights to the value
attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

attention_probs has shape (2, 4, 10, 10) and value has shape (2, 4, 10, 192); their matmul gives (2, 4, 10, 192), i.e., (bsz, num_heads, seq_len, head_dim). permute turns this into (bsz, seq_len, num_heads, head_dim), and view then merges the last two dimensions, num_heads x head_dim, back into d_model, so the final attention_output has shape (bsz, seq_len, d_model), the same as the vectors produced by mapping the original token_ids through the Embedding layer.

tensor([[[-0.0421,  0.0127, -0.3383,  ...,  0.1617, -0.2079, -0.3181],
         [-0.0401,  0.0172, -0.3488,  ...,  0.1581, -0.2098, -0.3260],
         [-0.0397,  0.0130, -0.3420,  ...,  0.1581, -0.2081, -0.3266],
         ...,
         [-0.0428,  0.0137, -0.3372,  ...,  0.1589, -0.2145, -0.3305],
         [-0.0424,  0.0120, -0.3419,  ...,  0.1574, -0.2198, -0.3344],
         [-0.0426,  0.0140, -0.3463,  ...,  0.1562, -0.2191, -0.3338]],

        [[ 0.0502,  0.0926, -0.3300,  ...,  0.1376, -0.1264, -0.3689],
         [ 0.0369,  0.0917, -0.3339,  ...,  0.1439, -0.1089, -0.3571],
         [ 0.0419,  0.0915, -0.3328,  ...,  0.1480, -0.1168, -0.3654],
         ...,
         [ 0.0438,  0.0946, -0.3290,  ...,  0.1435, -0.1302, -0.3702],
         [ 0.0417,  0.0898, -0.3281,  ...,  0.1358, -0.1374, -0.3759],
         [ 0.0428,  0.0906, -0.3306,  ...,  0.1345, -0.1330, -0.3752]]])

- Output Linear Layer (shallow MLP)

# Apply a linear layer to the output
x = self.fc_out(attention_output)

fc_out is an nn.Linear(embed_dim, embed_dim), so passing attention_output through it computes (bsz, seq_len, d_model) x (d_model, d_model) = (bsz, seq_len, d_model).

tensor([[[ 1.0228e-01,  1.6250e-01, -1.4914e-01,  ..., -1.7511e-01,
          -2.1751e-03, -2.0877e-02],
         [ 9.9930e-02,  1.6427e-01, -1.4394e-01,  ..., -1.7894e-01,
           1.9605e-03, -2.4290e-02],
         [ 1.0188e-01,  1.6577e-01, -1.4313e-01,  ..., -1.7274e-01,
           5.3616e-03, -1.8874e-02],
         ...,
         [ 1.0584e-01,  1.6541e-01, -1.4315e-01,  ..., -1.7077e-01,
          -4.8522e-04, -2.2207e-02],
         [ 1.0028e-01,  1.6638e-01, -1.3908e-01,  ..., -1.7138e-01,
          -4.0303e-05, -2.2604e-02],
         [ 1.0054e-01,  1.6448e-01, -1.4135e-01,  ..., -1.7086e-01,
           2.8514e-03, -1.9951e-02]],

        [[ 4.9912e-02,  1.3306e-01, -1.2705e-01,  ..., -1.2117e-01,
           3.5498e-02,  3.8191e-03],
         [ 4.8556e-02,  1.3361e-01, -1.2207e-01,  ..., -1.2270e-01,
           3.7410e-02, -3.3710e-03],
         [ 4.9592e-02,  1.3507e-01, -1.2446e-01,  ..., -1.2247e-01,
           4.3996e-02,  2.0591e-03],
         ...,
         [ 5.2688e-02,  1.3105e-01, -1.2519e-01,  ..., -1.1373e-01,
           3.7038e-02,  2.5118e-03],
         [ 4.8786e-02,  1.3443e-01, -1.1793e-01,  ..., -1.1811e-01,
           3.4455e-02,  3.0611e-04],
         [ 4.7252e-02,  1.3401e-01, -1.1889e-01,  ..., -1.1601e-01,
           3.6708e-02,  2.7476e-03]]])

4.Complete Code

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class BITDDDAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(BITDDDAttention, self).__init__()

        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads

        # Build the Q/K/V projections and the final fully-connected (MLP) output layer
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Apply the attention weights to the value
        attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

        # Apply a linear layer to the output
        x = self.fc_out(attention_output)
        return x


if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pretrained_bert = BertModel.from_pretrained('bert-base-uncased')

    input_texts = ["This is a test sentence.", "Here is another test sentence."]
    input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,
                                  return_tensors='pt') for text in input_texts]
    input_ids = torch.cat(input_ids, dim=0)  # Concatenate the per-sample (1, 10) tensors into a (2, 10) batch

    with torch.no_grad():
        embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT model
        print(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

        embed_dim = embedded_output.size(-1)
        num_heads = 4
        model = BITDDDAttention(embed_dim, num_heads)

        output = model(embedded_output)
        print(output.size())  # Output shape should be (2, 10, embed_dim)

IV.Summary

The code above runs in a local CPU/GPU environment; set breakpoints and step through it to get familiar with how the dimensions change and how the computation proceeds. Multi-Head Attention splits the computation across multiple heads, computes attention weights between tokens in each head, and takes the weighted sum. For a decoder-only architecture, a Causal Mask is added on top of this so that earlier tokens cannot see later ones. This post stops at the Transformer output; in a follow-up we will look at how the lm_head on top of the last Transformer layer produces next-token probabilities and how the cross-entropy loss is computed.
