transformers/models/qwen2_moe/modeling_qwen2_moe.py [Qwen2-MoE Source Code Walkthrough]
class Qwen2MLP(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
Here is a line-by-line walkthrough of this code, connecting it to the Qwen2 paper, in particular its MoE (Mixture-of-Experts) part. Note that the MoE file defines its own MLP variant, Qwen2MoeMLP, which adds an intermediate_size argument on top of the Qwen2MLP shown above:

class Qwen2MoeMLP(nn.Module):

This defines a new class, Qwen2MoeMLP, inheriting from PyTorch's nn.Module. It implements the MLP (multi-layer perceptron) used for each expert in the Qwen2 MoE model.

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

The constructor takes a config object and an optional intermediate_size, and sets the basic attributes from the configuration. The intermediate_size parameter allows the width of the intermediate layer to be chosen flexibly, which matters for the "expert granularity" idea discussed in the paper.

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

These lines define three linear projections:
- gate_proj: the gating branch.
- up_proj: the up projection.
- down_proj: the down projection.
Together they form the SwiGLU block mentioned in the paper.

        self.act_fn = ACT2FN[config.hidden_act]

This selects the activation function from the configuration; for Qwen2 this is typically SiLU, which combined with the gating above yields the SwiGLU activation mentioned in the paper.

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

This is the MLP forward pass and applies SwiGLU: the output of gate_proj goes through the activation function and is multiplied element-wise with the output of up_proj; the result is then passed through down_proj.
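To make the SwiGLU data flow concrete, here is a minimal, self-contained sketch; the dimensions are toy values chosen only for illustration, not the real configuration:

import torch
import torch.nn as nn

hidden_size, intermediate_size = 4, 8            # made-up toy sizes
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.SiLU()                               # the usual hidden_act for Qwen2-style configs

x = torch.randn(2, 3, hidden_size)               # (batch, seq_len, hidden_size)
y = down_proj(act_fn(gate_proj(x)) * up_proj(x)) # SwiGLU: gated branch * up branch, then down projection
print(y.shape)                                   # torch.Size([2, 3, 4])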
Next comes the repeat_kv helper:

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

This function supports grouped-query attention (GQA), which the paper describes as a way to reduce KV-cache usage during inference.

    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

It reads the shape of the input tensor and returns the input unchanged if no repetition is needed.

    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

These lines perform the actual repetition of the key/value states: the tensor is expanded along a new dimension and then reshaped so that each key/value head appears n_rep times. This is how the GQA implementation shares a small set of KV heads across all query heads, reducing KV-cache memory during inference.

This code implements core building blocks of the Qwen2 MoE model: the expert MLP structure and the GQA helper. It matches the paper's description of SwiGLU activations, grouped-query attention, and flexible expert sizing in the MoE setup.
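A quick demonstration of what repeat_kv does to the shapes (toy values; 4 KV heads expanded to serve 12 query heads):

import torch

batch, num_key_value_heads, seqlen, head_dim = 2, 4, 5, 16
n_rep = 3
k = torch.randn(batch, num_key_value_heads, seqlen, head_dim)

# Same expand + reshape trick as repeat_kv above.
expanded = k[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seqlen, head_dim)
k_rep = expanded.reshape(batch, num_key_value_heads * n_rep, seqlen, head_dim)
print(k_rep.shape)  # torch.Size([2, 12, 5, 16])

# It matches torch.repeat_interleave on the head dimension, as the docstring claims.
assert torch.equal(k_rep, torch.repeat_interleave(k, repeats=n_rep, dim=1))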
class Qwen2MoeAttention(nn.Module):
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
class Qwen2MoeAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = Qwen2MoeRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
Here is a line-by-line explanation of this code, tied to the Qwen2 technical report:

class Qwen2MoeAttention(nn.Module):

This class defines the attention mechanism of the Qwen2 Mixture-of-Experts (MoE) model. As the report notes, the Qwen2 family includes both dense and MoE models.

    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):

The constructor takes a configuration object and an optional layer index; the layer index matters for caching during auto-regressive decoding.

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

These lines set up the dimensions of the attention mechanism. The report mentions grouped-query attention (GQA), which is realized here by configuring different numbers of attention heads and key/value heads.

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

These relate to positional encoding. The report mentions the use of rotary position embeddings (RoPE).

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

These are the projection layers for queries, keys, values, and the output. Note that the key and value projections use num_key_value_heads, which is how the GQA described in the report is implemented.

        self.rotary_emb = Qwen2MoeRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

This initializes the rotary embedding module, implementing the RoPE mentioned in the report.
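For intuition, here is a minimal sketch of how rotary embeddings are typically built and applied; it mirrors the standard rotate-half formulation rather than the exact Qwen2MoeRotaryEmbedding code, and the sizes are toy values:

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len, base = 8, 6, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)            # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)     # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 2, seq_len, head_dim)    # (batch, heads, seq, head_dim)
q_rot = q * cos + rotate_half(q) * sin      # position-dependent rotation of each query vector
print(q_rot.shape)                          # torch.Size([1, 2, 6, 8])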
Moving on to the forward pass:

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):

This is the forward pass of the attention mechanism. It takes parameters for caching and for returning attention weights, which are useful for auto-regressive generation and model analysis.

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

These lines project the input hidden states into query, key, and value states.

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

These lines reshape the states into the multi-head layout, reflecting the GQA structure described in the report.

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

These lines apply the rotary position embeddings (RoPE) mentioned in the report.

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

These lines implement the key/value sharing aspect of GQA, repeating the keys and values so they line up with the larger number of query heads.

The rest of the function implements standard attention: masking, softmax, dropout, and the final output projection, consistent with the general Transformer architecture described in the report.
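The whole eager GQA path can be summarized with a small shape walkthrough (toy sizes, no mask or dropout):

import math
import torch

# Toy GQA setup: 8 query heads share 2 KV heads (num_key_value_groups = 4).
bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 5, 8, 2, 16
n_rep = num_heads // num_kv_heads

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
v = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Expand K/V to match the number of query heads (the same effect as repeat_kv).
k = torch.repeat_interleave(k, n_rep, dim=1)
v = torch.repeat_interleave(v, n_rep, dim=1)

scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)  # (1, 8, 5, 5)
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)                                     # (1, 8, 5, 16)
print(scores.shape, out.shape)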
class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
    """
    Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
Here is a line-by-line explanation of this code, tied to the Qwen2 technical report:

class Qwen2MoeFlashAttention2(Qwen2MoeAttention):

This class inherits from Qwen2MoeAttention and implements the Flash Attention 2 variant of attention for the Qwen2 MoE model. Flash Attention is an optimized attention kernel that significantly improves efficiency for large models.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

The constructor sets a flag used to handle the causal-mask alignment difference between Flash Attention versions before and after 2.1.

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):

This is the forward pass; its parameters are the same as those of the standard attention module.

        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

These lines compute the query, key, and value states, just as in standard attention.

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

The tensors are reshaped into the multi-head layout, again implementing the grouped-query attention (GQA) mentioned in the report.

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # ... (cache handling)

This part handles the key/value cache, which speeds up auto-regressive generation.

        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

These lines apply the rotary position embeddings (RoPE), an important feature of the Qwen2 models. Because inputs may be padded, the rotary sequence length is derived from the largest position id.

        if past_key_value is not None:
            # ... (sliding-window cache slicing)

This part trims the cache for sliding window attention, a key technique Qwen2 uses for long sequences.

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

This implements the key/value sharing of GQA.

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

This block decides whether sliding window attention is used, depending on the layer index and the configuration.

        attn_output = _flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len,
            dropout=dropout_rate, sliding_window=sliding_window,
            is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

This is the core Flash Attention 2 call, which computes the attention output efficiently.

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

Finally, the attention output is reshaped and passed through the output projection.

Overall, this class implements the Flash Attention 2 path of the Qwen2 MoE model and covers many of the key features from the report, such as GQA, RoPE, and sliding window attention. It is aimed at improving computational efficiency for large models, especially on long sequences.
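For completeness, this is roughly how one of these attention backends is selected in practice when loading a qwen2_moe checkpoint with the standard transformers API. The checkpoint name below is only an example; flash-attn (and accelerate for device_map) must be installed for this exact configuration to work:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-MoE-A2.7B"  # example checkpoint built on the qwen2_moe architecture

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,                # flash-attn kernels require fp16/bf16 inputs
    attn_implementation="flash_attention_2",   # picks Qwen2MoeFlashAttention2 from QWEN2MOE_ATTENTION_CLASSES
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)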
Next comes the SDPA variant of the attention module; a line-by-line explanation, combined with some relevant information from the Qwen2 report, follows the code:

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
    """
    Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from Qwen2MoeAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once implemented.
            logger.warning_once(
                "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
                "does not support `output_attentions=True`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA's memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs when a custom
        # attn_mask is used. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
        # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options;
        # an inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 check is necessary to match AttentionMaskConverter.to_causal_4d, which does not create a causal
        # mask when q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
Code explanation

- Class definition: defines a class named Qwen2MoeSdpaAttention that inherits from Qwen2MoeAttention.
- Docstring: explains that the class implements attention with torch.nn.functional.scaled_dot_product_attention while keeping the original module's weights untouched.
- Forward method (forward): computes the attention output.
  - Input arguments: hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position.
  - Warning handling: if output_attentions is True, a warning is emitted and the parent class's forward is called instead.
  - Query/key/value states: computed with the linear projections, then reshaped into the multi-head layout.
  - Rotary position embeddings: positional information is injected into the query and key states.
  - Cache update: if a past key/value cache exists, it is updated.
  - Key/value repetition: keys and values are repeated to match the number of attention heads.
  - Causal mask: the attention mask, if present, is sliced to the key length.
  - Contiguity fix: on CUDA with an attention mask, the query/key/value states are made contiguous.
  - is_causal decision: whether SDPA should build the causal mask itself is decided from the mask and q_len.
  - Attention output: computed with torch.nn.functional.scaled_dot_product_attention.
  - Output reshaping: the attention output is reshaped back to (batch, seq_len, hidden_size) and projected.
  - Return value: the attention output, None (no attention weights), and the past key/value cache.

Connection to the Qwen2 report

The Qwen2 series, including Qwen2-72B, shows strong performance across many benchmarks. The attention mechanism is a key component of these models, and Qwen2MoeSdpaAttention is one of its core implementations. The report attributes part of the models' strength in language understanding, generation, and multilingual ability to an efficient attention implementation.
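As a quick, self-contained illustration of the SDPA call this forward pass ultimately dispatches to (toy shapes, unrelated to the real model dimensions):

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 6, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# With attn_mask=None and is_causal=True, SDPA builds the causal mask internally,
# which is the fast path the code above tries to hit when q_len > 1.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 6, 16])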
QWEN2MOE_ATTENTION_CLASSES = {
    "eager": Qwen2MoeAttention,
    "flash_attention_2": Qwen2MoeFlashAttention2,
    "sdpa": Qwen2MoeSdpaAttention,
}


class Qwen2MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
Here is a line-by-line explanation of this code:

QWEN2MOE_ATTENTION_CLASSES = {
    "eager": Qwen2MoeAttention,
    "flash_attention_2": Qwen2MoeFlashAttention2,
    "sdpa": Qwen2MoeSdpaAttention,
}

This dictionary lists the attention implementations available to the Qwen2 MoE model: standard ("eager") attention, Flash Attention 2, and SDPA (PyTorch's torch.nn.functional.scaled_dot_product_attention).

class Qwen2MoeSparseMoeBlock(nn.Module):

This class defines the sparse Mixture-of-Experts (Sparse MoE) block of the Qwen2 MoE model.

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

The constructor stores the number of experts, the number of experts used per token, and whether the top-k probabilities are renormalized.

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

This defines the routing gate, which decides which experts each input token is sent to.

        self.experts = nn.ModuleList(
            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

This creates the list of experts; each expert is an MLP.

        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

This defines a shared expert and its gate; the shared expert processes every token, in addition to the routed experts.

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

The forward pass processes the input hidden states.

        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

The input is flattened to (batch * sequence_length, hidden_dim) for per-token routing.

        router_logits = self.gate(hidden_states)

This computes the router logits.

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

This computes the routing weights and selects the top-k experts for each token.

        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

If configured, the top-k probabilities are renormalized to sum to one, and the weights are cast back to the input dtype.

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

This initializes the output tensor.

        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

This one-hot encodes the selected experts into a mask of shape (num_experts, top_k, num_tokens), which is used below to find the tokens assigned to each expert.

This code implements the sparse mixture-of-experts mechanism of the Qwen2 MoE model, covering expert routing and selection. It is the key piece of conditional computation that lets the model dynamically pick different expert subnetworks for each input token.
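To see the routing math in isolation, here is a minimal sketch of the gate → softmax → top-k → expert-mask pipeline with made-up sizes (5 tokens, 4 experts, top-2):

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts, top_k = 5, 8, 4, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

router_logits = gate(hidden_states)                                   # (5, 4)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (5, 4)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)  # both (5, 2)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)          # the norm_topk_prob=True branch

# expert_mask[e, k, t] == 1 iff expert e is the k-th choice of token t.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
print(selected_experts)   # which 2 experts each token routes to
print(expert_mask.shape)  # torch.Size([4, 2, 5])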
        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        shared_expert_output = self.shared_expert(hidden_states)
        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

        final_hidden_states = final_hidden_states + shared_expert_output

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
Continuing the line-by-line explanation:

        for expert_idx in range(self.num_experts):

This loop iterates over all experts.

            expert_layer = self.experts[expert_idx]

This fetches the current expert's layer.

            idx, top_x = torch.where(expert_mask[expert_idx])

This finds the tokens routed to the current expert: top_x holds the token positions and idx holds the top-k slot the expert occupies for each of those tokens.

            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

This gathers the hidden states of the tokens routed to the current expert.

            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

The selected hidden states are passed through the current expert and multiplied by the corresponding routing weights.

            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

The expert's output is accumulated into the final hidden states; index_add_ scatters each token's contribution back to the correct row.

        shared_expert_output = self.shared_expert(hidden_states)

This computes the shared expert's output on all tokens.

        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

This applies the shared expert's gate: a per-token sigmoid scalar scales the shared expert output.

        final_hidden_states = final_hidden_states + shared_expert_output

The shared expert's output is added to the final hidden states.

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

The final hidden states are reshaped back to the original batch size and sequence length.

        return final_hidden_states, router_logits

The block returns the final hidden states and the router logits (the latter feed the load-balancing auxiliary loss).

This code implements the core logic of the sparse MoE block: it loops over the experts, feeds each expert the tokens routed to it, and merges the weighted expert outputs. A shared expert additionally captures features common to all tokens. Choosing different computation paths per token lets the model increase capacity without a proportional increase in per-token compute.
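The scatter-and-gate pattern above can be reproduced in a few lines. The sketch below uses plain nn.Linear modules as stand-ins for the expert MLPs and made-up routing results, just to show how index_add_ and the sigmoid shared-expert gate combine:

import torch

num_tokens, hidden_dim = 5, 8
hidden_states = torch.randn(num_tokens, hidden_dim)
final_hidden_states = torch.zeros_like(hidden_states)

# Suppose tokens 1 and 3 were routed to some expert with these routing weights (made up).
top_x = torch.tensor([1, 3])
weights = torch.tensor([[0.7], [0.4]])
expert = torch.nn.Linear(hidden_dim, hidden_dim)            # stand-in for a Qwen2MoeMLP expert

expert_out = expert(hidden_states[top_x]) * weights         # (2, 8), weighted expert output
final_hidden_states.index_add_(0, top_x, expert_out)        # scatter back to rows 1 and 3

# Shared expert: always applied, scaled by a learned per-token sigmoid gate.
shared_expert = torch.nn.Linear(hidden_dim, hidden_dim)     # stand-in for the shared Qwen2MoeMLP
shared_gate = torch.nn.Linear(hidden_dim, 1, bias=False)
final_hidden_states += torch.sigmoid(shared_gate(hidden_states)) * shared_expert(hidden_states)
print(final_hidden_states.shape)                            # torch.Size([5, 8])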
class Qwen2MoeDecoderLayer(nn.Module):
    def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
Continuing the line-by-line explanation:

class Qwen2MoeDecoderLayer(nn.Module):

This defines the decoder layer class of the Qwen2 MoE model.

    def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

The constructor takes the configuration and the layer index, and stores the hidden size.

        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

The self-attention layer is instantiated according to the attention implementation specified in the configuration.

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)

This block decides whether the layer uses a sparse MoE block or a plain MLP. A layer gets the MoE block only when all of the following hold (see the sketch at the end of this section):
- the current layer is not in the MLP-only layer list;
- the number of experts is greater than 0;
- the layer index satisfies the sparse step, i.e. every decoder_sparse_step-th layer.

        self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

The input layer norm and the post-attention layer norm are initialized (RMSNorm in both cases).

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

This is the signature of the forward pass, which takes several optional arguments:
- hidden_states: the input hidden states
- attention_mask: the attention mask
- position_ids: the position ids
- past_key_value: the cached key/value pairs
- output_attentions: whether to return attention weights
- output_router_logits: whether to return the router logits
- use_cache: whether to use the KV cache
- cache_position: the cache positions

The body of the forward pass is implemented in the next part of the code.

This class implements one decoder layer of the Qwen2 MoE model, combining self-attention with either an MLP or a MoE block depending on the layer. Allowing different layers to use different computation structures is a common optimization in large language models.
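As a small illustration of the placement rule above (the config values below are made up; real checkpoints define their own decoder_sparse_step, mlp_only_layers, and num_experts):

# Which layers get a sparse MoE block vs. a plain dense MLP, under toy config values.
num_hidden_layers = 8
decoder_sparse_step = 2        # every 2nd layer becomes a sparse MoE layer
mlp_only_layers = [5]          # layers forced to use a plain dense MLP
num_experts = 4

for layer_idx in range(num_hidden_layers):
    use_moe = (
        layer_idx not in mlp_only_layers
        and num_experts > 0
        and (layer_idx + 1) % decoder_sparse_step == 0
    )
    print(layer_idx, "Qwen2MoeSparseMoeBlock" if use_moe else "Qwen2MoeMLP")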
This code defines a class ClassInstantier that manages the instantiation of activation functions, plus a dictionary ACT2CLS that maps activation-function names to the corresponding class (or a class together with its constructor arguments). It also defines a function get_activation that returns an activation-function instance by name. At the end, get_activation is called to instantiate a few activation functions for backward compatibility.

Line-by-line explanation:

class ClassInstantier(OrderedDict):

Defines a class named ClassInstantier that inherits from OrderedDict.

    def __getitem__(self, key):

Overrides __getitem__ so that looking up a key performs custom processing.

        content = super().__getitem__(key)

Fetches the value content stored under key in the parent OrderedDict.

        cls, kwargs = content if isinstance(content, tuple) else (content, {})

Checks whether content is a tuple. If so, it is unpacked into cls and kwargs; otherwise content becomes cls and kwargs is an empty dict.

        return cls(**kwargs)

Creates an instance from cls and kwargs and returns it.

ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}

Defines a dictionary ACT2CLS whose keys are activation-function names and whose values are the corresponding classes, or tuples of a class and its constructor arguments.

ACT2FN = ClassInstantier(ACT2CLS)

Wraps ACT2CLS in a ClassInstantier to produce ACT2FN, so that indexing ACT2FN by name returns a freshly constructed activation instance.

def get_activation(activation_string):

Defines a function get_activation that returns an activation-function instance by name.

    if activation_string in ACT2FN:
        return ACT2FN[activation_string]

If activation_string is a key of ACT2FN, the corresponding activation instance is returned.

    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")

Otherwise, a KeyError is raised stating that the requested activation function was not found.

# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")

For backward compatibility, get_activation is used to instantiate several activation functions and bind them to module-level names.