LLM - Transformer && Multi-Head Attention: Dimension Changes and Source Code Walkthrough
I. Introduction
In earlier posts we covered the implementation of the Causal Mask and RoPE (rotary position embedding) based on LLM source code. This article walks through the Transformer implementation in that source code, analyzing the dimension changes and the meaning of the code line by line, so that the flow of a Transformer inside an LLM becomes clear.
II. Transformer Dimensions by Layer
The basic Transformer structure has come up many times in earlier posts; here we go over it once more with a focus on dimension changes. For a more detailed introduction, see: LLM - Transformer && LLaMA2 结构分析与 LoRA 详解.
1. Single Sample
- Embedding Layer
For a typical LLM, the input embedding dimension d_model is the size of the continuous vector that each input token is mapped to by the embedding layer. For example, d_model is 768 in BERT-base, while in today's large models it is often much larger, e.g. 8192.
- Transformer Layer
The output dimension of a Transformer layer usually matches the embedding dimension, i.e. d_model. Sticking with the BERT-base example, the output of every Transformer layer [called an encoder layer in BERT; in LLMs these are mostly decoder layers] is again a vector of dimension 768 (or 8192 in the large-model case).
- lm_head Layer
The final lm_head (language-model head) usually has an output dimension equal to the vocabulary size vocab_size, because its job is to turn the Transformer output into a probability distribution over the vocabulary. For example, if the model's vocabulary contains 30000 tokens, the lm_head output dimension is 30000.
- hidden_states
hidden_states is a term for the intermediate outputs produced while a Transformer model runs, commonly used when inspecting and analyzing the model internals. It records the hidden-layer activations: for every input token, each Transformer layer produces an output vector that represents the input at that particular depth of the network. For an N-layer stack and a given input sequence, the model therefore yields N such sets of hidden states. Alongside them, each layer also produces an attention distribution, a key part of the self-attention mechanism that lets the model weigh the dependencies between different parts of the input while processing it.
Tips:
Suppose we have a BERT-base model with 12 Transformer layers, each with an output dimension of 768. If we feed in a sequence of 5 tokens, each token is first converted into a 768-dimensional embedding vector, so at the start hidden_states is a tensor of shape (5, 768). After the 12 Transformer layers, stacking the per-layer outputs gives a 3-D tensor of shape (12, 5, 768), which contains the representation of every token at every layer.
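As a concrete check of these shapes, here is a minimal sketch using the Hugging Face transformers API (the exact API usage here is an assumption of this sketch; the article's own demo code appears in section III). Note that with output_hidden_states=True the returned tuple also includes the embedding output, so it has num_layers + 1 = 13 entries for BERT-base:
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

inputs = tokenizer("This is a test sentence.", return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)

hidden_states = outputs.hidden_states   # tuple: embedding output + 12 layer outputs
print(len(hidden_states))               # 13
print(hidden_states[-1].shape)          # torch.Size([1, seq_len, 768])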
2. Batched Samples
The above shows the flow for a single sample; now let's look at how the dimensions change with batched input. Suppose we have a BERT-base model with:
vocabulary size vocab_size = 30000
embedding dimension d_model = 768
number of stacked layers N = 12
maximum sequence length max_seq_length = 128
batch size batch_size = 32
Here is how the dimensions change as data flows through the model:
- Input Layer
The input layer has shape (batch_size, max_seq_length), i.e. (32, 128). Each row of 128 values holds the token_ids of one sequence in the batch, i.e. the result of running the text through the tokenizer.
- Embedding Layer
The (bsz, max_seq_length) integer tensor is fed into the embedding layer. Taking BERT as an example, it is mapped to shape (bsz, max_seq_length, d_model), i.e. (32, 128, 768): we now have 32 samples, each a sequence of 128 word-embedding vectors of dimension 768.
- Transformer Layer
Each Transformer layer takes a tensor of shape (bsz, max_seq_length, d_model) and, after multi_head_attention, outputs a tensor of the same shape, since Transformer layers normally keep input and output dimensions identical. After this layer the shape is therefore still (bsz, max_seq_length, d_model), i.e. (32, 128, 768).
- lm_head Layer
The lm_head linear layer maps the Transformer output from (bsz, max_seq_length, d_model) to (bsz, max_seq_length, vocab_size), i.e. (32, 128, 30000). This layer is usually implemented with a Linear layer; more complex LLMs may add MLP layers, but the purpose of lm_head is always to project d_model onto vocab_size, producing for every position a vector of logits over the vocabulary that represents the distribution over possible tokens.
Tips:
If we also track the intermediate hidden_states, then for every token in the sequence each Transformer layer produces a 768-dimensional vector, so for the whole batch the hidden_states of each layer has shape (batch_size, max_seq_length, d_model), i.e. (32, 128, 768). If we keep the hidden_states of all layers, we get a 4-D tensor of shape (num_layers, batch_size, max_seq_length, d_model), i.e. (12, 32, 128, 768), where num_layers is the N mentioned above, the number of stacked Transformer layers in the LLM. This is how the different dimensions evolve as data flows through the model. Note that in practice, because sequence lengths vary, padding and masking are also involved to make batching work, but they do not change the basic dimension flow described above; a shape-only sketch of this flow follows below.
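The following sketch reproduces the batched dimension flow end to end with the toy numbers above. It is purely illustrative: the transformer stack is stood in for by PyTorch's generic nn.TransformerEncoderLayer (an assumption of this sketch, not BERT's actual implementation), since any such layer preserves the (batch_size, max_seq_length, d_model) shape:
import torch
import torch.nn as nn

vocab_size, d_model, num_layers = 30000, 768, 12
max_seq_length, batch_size = 128, 32

embedding = nn.Embedding(vocab_size, d_model)                  # token_id -> d_model vector
layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model, nhead=12, batch_first=True)
                        for _ in range(num_layers)])           # stand-in transformer stack
lm_head = nn.Linear(d_model, vocab_size)                       # d_model -> vocab_size logits

input_ids = torch.randint(0, vocab_size, (batch_size, max_seq_length))   # (32, 128)
with torch.no_grad():  # roughly a BERT-base-sized forward pass, so it takes a moment on CPU
    hidden = embedding(input_ids)                              # (32, 128, 768)
    all_hidden_states = []
    for layer in layers:
        hidden = layer(hidden)                                 # shape preserved: (32, 128, 768)
        all_hidden_states.append(hidden)
    logits = lm_head(hidden)                                   # (32, 128, 30000)

print(torch.stack(all_hidden_states).shape)                    # torch.Size([12, 32, 128, 768])
print(logits.shape)                                            # torch.Size([32, 128, 30000])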
III. Transformer Dimension Transformations
So that everyone can debug and test quickly on a local machine, the examples below use BERT and its tokenizer as the base model to build the token_ids and embeddings. The multi-head attention that follows is ported from Qwen's logic while keeping the main implementation style unchanged; for more complete code see the modeling.py on HF.
1.Input Layer
The input layer and embedding layer are obtained via the BERT model's tokenizer:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pretrained_bert = BertModel.from_pretrained('bert-base-uncased')

    input_texts = ["This is a test sentence.", "Here is another test sentence."]
    input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length',
                                  truncation=True, return_tensors='pt') for text in input_texts]
    input_ids = torch.cat(input_ids, dim=0)  # Concatenate along the batch dimension -> (2, 10)
For convenience, input_texts contains two samples, so bsz = 2, max_length = 10 and d_model = 768. With return_tensors='pt', each encoded input_ids tensor has shape (1, 10).
Concatenating them gives the initial shape (bsz, max_length) = (2, 10):
tensor([[ 101, 2023, 2003, 1037, 3231, 6251, 1012, 102, 0, 0],
[ 101, 2182, 2003, 2178, 3231, 6251, 1012, 102, 0, 0]])
2.Embedding Layer
with torch.no_grad():
    embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT model
print(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)
Here we run input_ids through the pretrained BERT model and use its output as the embedding input for the attention demo (strictly speaking, pretrained_bert(input_ids)[0] is BERT's last hidden state rather than just the embedding layer output, but the shape is the same). Since d_model = 768, the (bsz, max_length) token_id tensor becomes (bsz, max_length, d_model), i.e. (2, 10, 768):
tensor([[[-3.7545e-02, 5.3234e-04, -1.3553e-02, ..., -1.9545e-01,
2.3569e-01, 4.7479e-01],
[-7.1746e-01, -2.8763e-01, 1.4100e-01, ..., -5.5593e-01,
6.1830e-01, 3.9255e-01],
[-1.9318e-01, -4.0202e-01, 3.2924e-01, ..., -1.5206e-01,
3.4014e-01, 1.0233e+00],
...,
[ 1.5273e-01, 1.1651e-01, 1.5754e-01, ..., 6.9833e-02,
-8.5732e-01, -4.3875e-02],
[ 7.0679e-02, -2.3521e-01, 6.1713e-01, ..., -7.3852e-02,
2.5070e-01, -6.3240e-02],
[-1.3249e-01, -3.6026e-01, 3.5025e-01, ..., -5.5981e-02,
1.0420e-01, -4.3954e-01]],
[[-2.9592e-02, -1.4164e-01, -2.2295e-03, ..., -1.3087e-01,
2.9421e-01, 5.5132e-01],
[-1.0146e+00, -6.8757e-01, 1.9959e-01, ..., -4.2000e-01,
1.7332e-01, 9.2754e-02],
[-1.3425e-01, -8.1044e-01, 2.6674e-01, ..., 4.6978e-02,
-1.0026e-01, 4.5293e-01],
...,
[ 4.5527e-01, 2.2234e-02, -3.6816e-01, ..., 4.3154e-01,
-8.6396e-01, -2.8542e-01],
[ 1.4188e-01, -2.4001e-01, 6.5681e-01, ..., -5.7224e-02,
3.1025e-01, -9.0286e-02],
[-3.9205e-02, -3.2815e-01, 4.7910e-01, ..., -4.7641e-02,
2.9916e-02, -4.5328e-01]]])
3.Multi-Head Attention
embed_dim = embedded_output.size(-1)
num_heads = 4
model = BITDDDAttention(embed_dim, num_heads)
output = model(embedded_output)
print(output.size()) # Output shape should be (2, 10, embed_dim)
In this layer we port the core of the Attention implementation from the LLM's modeling.py into the BITDDDAttention class:
class BITDDDAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BITDDDAttention, self).__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        # Build the Q/K/V projections and the final output projection
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # Apply the attention weights to the value
        attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)
        # Apply a linear layer to the output
        x = self.fc_out(attention_output)
        return x
Let's now walk through the execution flow and dimension changes of Multi-Head Attention line by line:
- Size
batch_size, seq_len, _ = x.size()
This line reads the bsz and seq_len of the batch fed into the attention layer; since embed_dim was already provided in __init__, the last dimension is ignored with '_'.
- Computing Q/K/V
# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
self.query, self.key and self.value are all nn.Linear(embed_dim, embed_dim) projection layers, and Q/K/V are processed identically: view (a reshape) turns the projected tensor from (bsz, seq_len, embed_dim) into (bsz, seq_len, num_heads, head_dim), and permute then swaps dimensions to give an output of shape (bsz, num_heads, seq_len, head_dim) for the subsequent multi-head computation. The assert in __init__ guarantees that embed_dim divides evenly across the heads:
self.head_dim = embed_dim // num_heads
With the num_heads and embed_dim passed to __init__ above (4 and 768), the final shape is (2, 4, 10, 192); a stand-alone shape check follows the tensor dump below:
tensor([[[[ 0.0185, -0.1872, 0.1827, ..., -0.7914, -0.0074, -0.6228],
[-0.1390, 0.4675, 0.0325, ..., 0.0187, 0.0912, -0.2692],
[-0.1342, 0.2904, -0.2637, ..., 0.1130, -0.0226, -0.3510],
...,
[-0.1151, -0.2627, -0.6453, ..., 0.4885, -0.1982, -0.1538],
[-0.1281, 0.2321, -0.0815, ..., -0.1740, 0.4909, -0.1373],
[-0.1364, 0.2844, -0.0728, ..., -0.0620, 0.3605, -0.2292]],
[[ 0.3641, 0.1707, -0.0567, ..., 0.0267, 0.3272, 0.1560],
[-0.1206, 0.6853, 0.0990, ..., -0.0875, 0.2414, 0.5490],
[-0.4080, 0.0679, 0.3174, ..., 0.0970, -0.0127, 0.1664],
...,
[-0.2878, 0.2856, 0.0777, ..., -0.0791, 0.0847, 0.0545],
[ 0.2381, -0.1032, 0.2887, ..., 0.2219, 0.2837, 0.0345],
[ 0.1421, -0.0956, 0.1983, ..., 0.1784, 0.1827, 0.0776]],
[[-0.2031, -0.2496, -0.0072, ..., -0.1553, -0.0441, 0.0200],
[-0.2028, -0.4097, 0.1779, ..., 0.0333, -0.4005, -0.3453],
[ 0.0926, -0.1818, 0.0492, ..., 0.3059, -0.6175, -0.2858],
...,
[ 0.3494, -0.4813, 0.7086, ..., 0.6181, 0.1515, -0.1279],
[-0.0542, 0.3148, 0.0172, ..., 0.0037, -0.2878, -0.1582],
[-0.1381, 0.2450, 0.0490, ..., -0.0824, -0.2504, -0.2464]],
[[ 0.6905, -0.1202, 0.6489, ..., 0.6069, 0.2634, -0.0595],
[ 0.3937, -0.2795, 0.7692, ..., 0.1321, -0.0240, -0.1484],
[ 0.2260, -0.4332, 0.4651, ..., -0.1797, -0.1127, -0.3294],
...,
[ 0.0168, -0.2892, 0.4032, ..., -0.4515, 0.3833, -0.7699],
[ 0.1970, -0.3264, 0.4196, ..., 0.3044, -0.0819, -0.2083],
[ 0.2492, -0.3419, 0.5813, ..., 0.1855, -0.2431, -0.1149]]],
[[[ 0.0225, -0.2359, 0.0754, ..., -0.7577, 0.0936, -0.6233],
[ 0.0479, 0.5459, -0.3047, ..., -0.3134, 0.0416, 0.0397],
[ 0.1172, 0.2506, -0.5461, ..., 0.1287, -0.0441, -0.2074],
...,
[-0.3586, -0.3827, -0.6436, ..., 0.3915, -0.2485, -0.1576],
[-0.1502, 0.1852, -0.1007, ..., -0.1310, 0.5079, -0.1868],
[-0.1622, 0.2055, -0.1428, ..., -0.0887, 0.3516, -0.2383]],
[[ 0.4338, 0.2326, -0.0661, ..., 0.0309, 0.3088, 0.1711],
[-0.4011, 0.9250, 0.2983, ..., -0.4108, 0.4223, 0.6880],
[-0.2721, 0.4383, 0.6376, ..., -0.0888, -0.0647, -0.0073],
...,
[ 0.1742, 0.2020, -0.1020, ..., -0.1444, 0.2459, 0.1079],
[ 0.2608, -0.0978, 0.2557, ..., 0.2132, 0.2125, 0.0010],
[ 0.1041, -0.1335, 0.1523, ..., 0.1797, 0.1323, 0.0036]],
[[-0.1826, -0.2200, -0.0026, ..., -0.1664, -0.0773, 0.0607],
[-0.1257, -0.2642, 0.6933, ..., 0.4202, -0.1153, -0.3960],
[-0.1353, -0.4837, 0.3527, ..., 0.3592, -0.5616, -0.3685],
...,
[ 0.6056, -0.3298, 0.7872, ..., 0.3984, 0.4775, 0.2213],
[-0.1211, 0.3394, -0.0247, ..., 0.0251, -0.3108, -0.1656],
[-0.1572, 0.3040, 0.0164, ..., -0.1026, -0.2737, -0.2175]],
[[ 0.7236, -0.1187, 0.6491, ..., 0.6230, 0.2401, -0.0061],
[-0.0402, -0.0318, 0.7717, ..., -0.0389, 0.1465, -0.3047],
[ 0.2734, -0.4473, 0.6278, ..., -0.3827, -0.0412, -0.7133],
...,
[-0.0418, 0.0670, 0.1462, ..., -0.6109, 0.4838, -0.4277],
[ 0.2340, -0.3250, 0.4256, ..., 0.3217, -0.0688, -0.1837],
[ 0.1940, -0.2878, 0.5281, ..., 0.2155, -0.1810, -0.0649]]]])
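Here is that shape check, run on a random tensor rather than the BERT embeddings purely to verify the split; it also shows the equivalent reshape + transpose formulation often seen in modeling code:
import torch

bsz, seq_len, embed_dim, num_heads = 2, 10, 768, 4
head_dim = embed_dim // num_heads                                    # 192

x = torch.randn(bsz, seq_len, embed_dim)
q = x.view(bsz, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)    # (2, 4, 10, 192)
print(q.shape)                                                       # torch.Size([2, 4, 10, 192])

# The same split written with reshape + transpose gives an identical result
q2 = x.reshape(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
print(torch.equal(q, q2))                                            # True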
- Computing the Attention Scores
# Compute the attention scores
attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
attention_probs = torch.softmax(attention_scores, dim=-1)
The attention scores are computed from Q and K. Here permute transposes key from (bsz, num_heads, seq_len, head_dim) to (bsz, num_heads, head_dim, seq_len), so the matmul yields attention_scores of shape (bsz, num_heads, seq_len, seq_len), i.e. (2, 4, 10, 10). Dividing by sqrt(head_dim) is the scaling step of scaled dot-product attention, which keeps the matmul products from growing too large, and softmax(dim=-1) normalizes the 10 scores in the last dimension so that each row sums to 1; a sketch comparing this with PyTorch's fused implementation follows the tensor dump below:
tensor([[[[0.1188, 0.1000, 0.0912, 0.0963, 0.0984, 0.0862, 0.1004, 0.0955,
0.1040, 0.1093],
[0.1077, 0.0949, 0.0963, 0.1035, 0.0961, 0.0940, 0.0946, 0.0890,
0.1090, 0.1149],
[0.0965, 0.1010, 0.0930, 0.0972, 0.1032, 0.1031, 0.0989, 0.0982,
0.1062, 0.1026],
[0.0932, 0.1033, 0.0970, 0.0977, 0.0961, 0.1050, 0.0990, 0.1082,
0.1006, 0.0999],
[0.0947, 0.1033, 0.0949, 0.0945, 0.0957, 0.1036, 0.0964, 0.0985,
0.1083, 0.1102],
[0.0941, 0.1026, 0.0939, 0.0942, 0.0953, 0.1008, 0.1001, 0.1089,
0.1038, 0.1063],
[0.1017, 0.1019, 0.1008, 0.0937, 0.1095, 0.1007, 0.0913, 0.0832,
0.1092, 0.1079],
[0.0926, 0.1121, 0.1009, 0.0991, 0.0955, 0.1017, 0.0964, 0.1030,
0.1004, 0.0982],
[0.1010, 0.1021, 0.0948, 0.0954, 0.0976, 0.1024, 0.0916, 0.1032,
0.1049, 0.1071],
[0.1046, 0.1077, 0.0932, 0.0948, 0.1006, 0.1002, 0.0934, 0.0983,
0.1042, 0.1032]],
......
[[0.1047, 0.0999, 0.1045, 0.1054, 0.0979, 0.1071, 0.0863, 0.0884,
0.1012, 0.1046],
[0.1019, 0.1113, 0.1010, 0.0990, 0.0981, 0.1060, 0.0872, 0.0915,
0.1017, 0.1022],
[0.0977, 0.0996, 0.0993, 0.1027, 0.0970, 0.0985, 0.0977, 0.0990,
0.1069, 0.1018],
[0.1042, 0.1125, 0.1049, 0.1022, 0.0981, 0.0950, 0.0864, 0.0876,
0.1045, 0.1046],
[0.0987, 0.1151, 0.1018, 0.0956, 0.0923, 0.0955, 0.0938, 0.0910,
0.1073, 0.1087],
[0.0985, 0.1143, 0.0936, 0.1029, 0.0954, 0.1028, 0.0857, 0.0901,
0.1076, 0.1092],
[0.0961, 0.0908, 0.1013, 0.1055, 0.0992, 0.1035, 0.0919, 0.0971,
0.1090, 0.1056],
[0.0874, 0.0935, 0.1012, 0.1057, 0.1044, 0.0968, 0.0936, 0.0950,
0.1132, 0.1091],
[0.1041, 0.1118, 0.0958, 0.0968, 0.0971, 0.1044, 0.0925, 0.0906,
0.1028, 0.1041],
[0.1028, 0.1123, 0.0971, 0.1005, 0.1013, 0.1020, 0.0896, 0.0894,
0.1026, 0.1023]]]])
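As mentioned above, the sketch below recomputes the same scores and softmax on random Q/K/V of the shapes used here, and (assuming PyTorch >= 2.0, an assumption of this sketch) checks the result against torch.nn.functional.scaled_dot_product_attention, which fuses exactly this computation:
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 10, 192)                                            # (bsz, num_heads, seq_len, head_dim)
k = torch.randn(2, 4, 10, 192)
v = torch.randn(2, 4, 10, 192)

attention_scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / 192 ** 0.5    # (2, 4, 10, 10)
attention_probs = torch.softmax(attention_scores, dim=-1)
print(attention_probs.shape)                                              # torch.Size([2, 4, 10, 10])
print(attention_probs.sum(dim=-1)[0, 0])                                  # each row sums to 1 after softmax

manual_output = torch.matmul(attention_probs, v)                          # (2, 4, 10, 192)
fused_output = F.scaled_dot_product_attention(q, k, v)                    # PyTorch >= 2.0
print(torch.allclose(manual_output, fused_output, atol=1e-5))             # True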
- Attention Output
# Apply the attention weights to the value
attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)
attention_probs has shape (2, 4, 10, 10) and value has shape (2, 4, 10, 192); their matmul gives (2, 4, 10, 192), i.e. (bsz, num_heads, seq_len, head_dim). permute rearranges this to (bsz, seq_len, num_heads, head_dim), and view then merges the last two dimensions num_heads x head_dim back into d_model, so the final attention_output has shape (bsz, seq_len, d_model), matching the shape of the original token_ids after the embedding layer. A short sketch of this merge, and of why .contiguous() is needed, follows the tensor dump below.
tensor([[[-0.0421, 0.0127, -0.3383, ..., 0.1617, -0.2079, -0.3181],
[-0.0401, 0.0172, -0.3488, ..., 0.1581, -0.2098, -0.3260],
[-0.0397, 0.0130, -0.3420, ..., 0.1581, -0.2081, -0.3266],
...,
[-0.0428, 0.0137, -0.3372, ..., 0.1589, -0.2145, -0.3305],
[-0.0424, 0.0120, -0.3419, ..., 0.1574, -0.2198, -0.3344],
[-0.0426, 0.0140, -0.3463, ..., 0.1562, -0.2191, -0.3338]],
[[ 0.0502, 0.0926, -0.3300, ..., 0.1376, -0.1264, -0.3689],
[ 0.0369, 0.0917, -0.3339, ..., 0.1439, -0.1089, -0.3571],
[ 0.0419, 0.0915, -0.3328, ..., 0.1480, -0.1168, -0.3654],
...,
[ 0.0438, 0.0946, -0.3290, ..., 0.1435, -0.1302, -0.3702],
[ 0.0417, 0.0898, -0.3281, ..., 0.1358, -0.1374, -0.3759],
[ 0.0428, 0.0906, -0.3306, ..., 0.1345, -0.1330, -0.3752]]])
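Here is that merge step in isolation (random input, same shapes as above):
import torch

attn = torch.randn(2, 4, 10, 192)                      # (bsz, num_heads, seq_len, head_dim)
merged = attn.permute(0, 2, 1, 3).contiguous().view(2, 10, 768)
print(merged.shape)                                    # torch.Size([2, 10, 768])

# permute only changes strides, so calling .view() on the non-contiguous result fails;
# .contiguous() first copies the data to a dense layout (.reshape() would also work).
try:
    attn.permute(0, 2, 1, 3).view(2, 10, 768)
except RuntimeError as e:
    print("view without contiguous:", e)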
- Output Linear Layer (Shallow MLP)
# Apply a linear layer to the output
x = self.fc_out(attention_output)
fc_out is an nn.Linear(embed_dim, embed_dim) layer, so attention_output is projected as (bsz, seq_len, d_model) x (d_model, d_model) = (bsz, seq_len, d_model).
tensor([[[ 1.0228e-01, 1.6250e-01, -1.4914e-01, ..., -1.7511e-01,
-2.1751e-03, -2.0877e-02],
[ 9.9930e-02, 1.6427e-01, -1.4394e-01, ..., -1.7894e-01,
1.9605e-03, -2.4290e-02],
[ 1.0188e-01, 1.6577e-01, -1.4313e-01, ..., -1.7274e-01,
5.3616e-03, -1.8874e-02],
...,
[ 1.0584e-01, 1.6541e-01, -1.4315e-01, ..., -1.7077e-01,
-4.8522e-04, -2.2207e-02],
[ 1.0028e-01, 1.6638e-01, -1.3908e-01, ..., -1.7138e-01,
-4.0303e-05, -2.2604e-02],
[ 1.0054e-01, 1.6448e-01, -1.4135e-01, ..., -1.7086e-01,
2.8514e-03, -1.9951e-02]],
[[ 4.9912e-02, 1.3306e-01, -1.2705e-01, ..., -1.2117e-01,
3.5498e-02, 3.8191e-03],
[ 4.8556e-02, 1.3361e-01, -1.2207e-01, ..., -1.2270e-01,
3.7410e-02, -3.3710e-03],
[ 4.9592e-02, 1.3507e-01, -1.2446e-01, ..., -1.2247e-01,
4.3996e-02, 2.0591e-03],
...,
[ 5.2688e-02, 1.3105e-01, -1.2519e-01, ..., -1.1373e-01,
3.7038e-02, 2.5118e-03],
[ 4.8786e-02, 1.3443e-01, -1.1793e-01, ..., -1.1811e-01,
3.4455e-02, 3.0611e-04],
[ 4.7252e-02, 1.3401e-01, -1.1889e-01, ..., -1.1601e-01,
3.6708e-02, 2.7476e-03]]])
4. Full Code
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class BITDDDAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BITDDDAttention, self).__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        # Build the Q/K/V projections and the final output projection
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # Apply the attention weights to the value
        attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)
        # Apply a linear layer to the output
        x = self.fc_out(attention_output)
        return x


if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pretrained_bert = BertModel.from_pretrained('bert-base-uncased')

    input_texts = ["This is a test sentence.", "Here is another test sentence."]
    input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length',
                                  truncation=True, return_tensors='pt') for text in input_texts]
    input_ids = torch.cat(input_ids, dim=0)  # Concatenate along the batch dimension -> (2, 10)

    with torch.no_grad():
        embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT model
    print(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

    embed_dim = embedded_output.size(-1)
    num_heads = 4
    model = BITDDDAttention(embed_dim, num_heads)
    output = model(embedded_output)
    print(output.size())  # Output shape should be (2, 10, embed_dim)
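Running this script end to end should print torch.Size([2, 10, 768]) twice: once for the BERT output fed into the attention layer and once for the BITDDDAttention output, confirming that the multi-head attention block preserves the (bsz, seq_len, d_model) shape.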
IV. Summary
The code above runs in a local CPU/GPU environment; set breakpoints and step through it to get familiar with how the dimensions change and how the computation proceeds. Multi-Head Attention uses multiple heads to compute attention weights between tokens and takes the corresponding weighted sums; a decoder-only architecture additionally applies a Causal Mask so that earlier tokens cannot see later ones. This article stops at the Transformer output; in a follow-up we will look at how the lm_head on top of the last Transformer layer produces the next_token probabilities and how the cross-entropy loss is computed.
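For completeness, here is a minimal sketch of how a causal mask could be applied to the attention_scores from section III (an illustration of the idea, not the exact modeling.py implementation): entries above the diagonal are set to -inf before the softmax, so each token can only attend to itself and earlier tokens.
import torch

bsz, num_heads, seq_len = 2, 4, 10
attention_scores = torch.randn(bsz, num_heads, seq_len, seq_len)

# True above the diagonal marks the "future" positions that must be hidden
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attention_scores = attention_scores.masked_fill(causal_mask, float('-inf'))
attention_probs = torch.softmax(attention_scores, dim=-1)

print(attention_probs[0, 0])   # lower-triangular rows: masked positions are exactly 0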