# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

Below is a line-by-line explanation of this code, relating it to the Qwen2 paper, in particular its MoE (Mixture-of-Experts) section:

class Qwen2MoeMLP(nn.Module):

This defines a new class, Qwen2MoeMLP, which inherits from PyTorch's nn.Module. It represents the MLP (multilayer perceptron) used for each expert in the Qwen2 MoE model.

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

The initializer receives a config object and an optional intermediate_size. It sets the basic attributes from the configuration. The intermediate_size parameter allows the intermediate dimension to be chosen flexibly, which matters for the "expert granularity" idea discussed in the paper.

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

These lines define three linear projections:

  1. gate_proj: used for the gating mechanism
  2. up_proj: the up projection
  3. down_proj: the down projection.
    Together, these projections form the SwiGLU block mentioned in the paper.

        self.act_fn = ACT2FN[config.hidden_act]

This sets the activation function from the config (SiLU, i.e. "silu", for Qwen2); combined with the gating above, it yields the SwiGLU activation mentioned in the paper.

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

This is the MLP's forward pass. It applies SwiGLU: the output of gate_proj goes through the activation function and is multiplied element-wise with the output of up_proj; the result is then passed through down_proj.
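
As a standalone illustration of this gated computation (a minimal sketch with made-up sizes, not Qwen2's real configuration):

import torch
import torch.nn as nn

hidden_size, intermediate_size = 8, 16           # toy sizes for illustration only
gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)
down = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()                                  # "silu"; with the gate this is SwiGLU

x = torch.randn(2, 5, hidden_size)               # (batch, seq_len, hidden)
y = down(act(gate(x)) * up(x))                   # silu(gate(x)) * up(x), projected back down
print(y.shape)                                   # torch.Size([2, 5, 8])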

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

This function supports grouped-query attention (GQA), which the paper describes as a way to reduce KV-cache usage during inference.

    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

Extract the shape of the input tensor; if no repetition is needed, return the input unchanged.

    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

These lines perform the actual repetition of the key/value states: the tensor is expanded along a new dimension and then reshaped, effectively repeating each key/value head n_rep times. This relates to the GQA implementation described in the paper, which reduces KV-cache usage during inference.

Taken together, this code implements core components of the Qwen2 MoE model: the expert MLP structure and the GQA mechanism. It matches the paper's description of SwiGLU activations, grouped-query attention, and flexible expert configuration in the MoE setup.
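
To see the shape transformation concretely, here is a small sketch (toy sizes, assuming the repeat_kv function above and import torch):

import torch

kv = torch.randn(1, 4, 10, 64)         # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=3)           # repeat each KV head 3 times
print(out.shape)                       # torch.Size([1, 12, 10, 64]); 12 = num_attention_heads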

# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
class Qwen2MoeAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = Qwen2MoeRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

Here is a line-by-line explanation of this code in the context of the Qwen2 technical report:

class Qwen2MoeAttention(nn.Module):

This class defines the attention mechanism of the Qwen2 Mixture-of-Experts (MoE) model. As the report notes, the Qwen2 family includes both dense and MoE models.

def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):

The constructor takes a configuration object and an optional layer index. The layer index matters for caching during auto-regressive decoding.

self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads

These lines set up the dimensions of the attention mechanism. The report mentions grouped-query attention (GQA), which is realized here by allowing different numbers of attention heads and key-value heads.
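
To make the numbers concrete, a small illustrative calculation (the values are hypothetical, not taken from an actual Qwen2 config):

hidden_size = 2048
num_attention_heads = 16
num_key_value_heads = 4                                              # hypothetical GQA setting

head_dim = hidden_size // num_attention_heads                        # 128
num_key_value_groups = num_attention_heads // num_key_value_heads    # 4 query heads share each KV head
print(head_dim, num_key_value_groups)                                # 128 4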

self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta

These relate to positional encoding; the report mentions the use of Rotary Position Embedding (RoPE).

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

These are the projection layers for queries, keys, values, and the output. Note that the key and value projections use num_key_value_heads, which is how the GQA described in the report is implemented.

self.rotary_emb = Qwen2MoeRotaryEmbedding(
    self.head_dim,
    max_position_embeddings=self.max_position_embeddings,
    base=self.rope_theta,
)

This initializes the rotary embedding module, implementing the RoPE mentioned in the report.

def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None):

This is the attention module's forward pass. Its parameters include options for caching and for returning attention weights, which are useful for auto-regressive generation and for model analysis.

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

These lines project the input hidden states into query, key, and value states.

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

These lines reshape the states into the multi-head layout, following the GQA structure mentioned in the report.

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

These lines apply the Rotary Position Embedding (RoPE) mentioned in the report.
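
For intuition, here is a much-simplified standalone sketch of what the rotary embedding does (base 10000 and toy sizes assumed; the real Qwen2MoeRotaryEmbedding and apply_rotary_pos_emb additionally handle caching, dtypes, and position_ids):

import torch

def rotate_half(x):
    # split the last dimension in half and swap the halves with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 8, 4                                             # toy sizes
pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(pos, inv_freq)                                   # (seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                                      # (seq_len, head_dim)

q = torch.randn(1, 1, seq_len, head_dim)                             # (batch, heads, seq, head_dim)
q_rot = q * cos + rotate_half(q) * sin                               # rotation mixes pairs of dimensions
print(q_rot.shape)                                                   # torch.Size([1, 1, 4, 8])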

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

These lines implement the key-value sharing aspect of GQA by repeating the keys and values so that multiple query heads can attend over them.

The rest of the function implements standard attention: masking, softmax, dropout, and the final output projection, consistent with the general Transformer architecture described in the report.

class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
    """
    Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

Here is a line-by-line explanation of this code, again referring to the Qwen2 technical report:

class Qwen2MoeFlashAttention2(Qwen2MoeAttention):

This class inherits from Qwen2MoeAttention and implements the Flash Attention 2 version of attention for the Qwen2 MoE model. FlashAttention is an optimized attention computation that significantly improves efficiency for large-scale models.

def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

The constructor initializes a flag that handles the causal-mask alignment difference between flash-attn versions (top-left before 2.1, bottom-right afterwards).

def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None):

This is the forward pass; its parameters are the same as those of the standard attention module.

bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

These lines compute the query, key, and value states, exactly as in the standard attention path.

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

Here the tensors are reshaped into the multi-head layout, again following the grouped-query attention (GQA) scheme from the report.

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
    # ... (cache handling logic)

This part handles the key-value cache, which makes auto-regressive generation more efficient.

rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

These lines apply the rotary position encoding (RoPE), an important feature of the Qwen2 models. Because the input can be padded, the rotary sequence length is taken as the maximum of the cached length and the largest position id.

if past_key_value is not None:
    # ... (sliding-window attention logic)

This part implements the sliding-window handling of the KV cache, a key technique Qwen2 uses for long sequences: when the cached length exceeds config.sliding_window, the oldest cache entries are sliced away.

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

This implements the key-value sharing used by GQA.

if (self.config.use_sliding_window and 
    getattr(self.config, "sliding_window", None) is not None and 
    self.layer_idx >= self.config.max_window_layers):
    sliding_window = self.config.sliding_window
else:
    sliding_window = None

This part decides whether a sliding window is passed to the flash-attention kernel, depending on the layer index and the configuration (it is used only when use_sliding_window is set and the layer index is at least config.max_window_layers).

attn_output = _flash_attention_forward(
    query_states, key_states, value_states, attention_mask, q_len,
    dropout=dropout_rate, sliding_window=sliding_window,
    is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask,
)

This is the core call into Flash Attention 2, which computes the attention output efficiently.

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)

Finally, the attention output is reshaped and passed through the output projection layer.

Overall, this class implements the Flash Attention 2 version of Qwen2 MoE attention, incorporating key features mentioned in the report such as GQA, RoPE, and sliding-window attention. It is aimed at improving computational efficiency for large models, especially on long sequences.

Below is a line-by-line explanation of the following code, together with relevant information from the Qwen2 report:

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
    """
    Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2MoeAttention` as the weights of the module stay untouched. The only change is in the forward pass, to adapt to the SDPA API.
    """

    # Adapted from Qwen2MoeAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning once `model.config.attn_implementation = "manual"` is implemented.
            logger.warning_once(
                "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with its memory-efficient backend is currently (torch==2.1.2) bugged for non-contiguous inputs with a custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline
        # conditional assignment inside SDPA, to support both torch.compile's dynamic shapes and full-graph options. An inline
        # conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match AttentionMaskConverter.to_causal_4d, which does not create a causal mask when q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

Code explanation

  1. Class definition: defines a class named Qwen2MoeSdpaAttention that inherits from Qwen2MoeAttention.
  2. Docstring: explains that the class implements attention with torch.nn.functional.scaled_dot_product_attention while keeping the original module's weights.
  3. Forward method (forward): computes the output of the attention mechanism.
    • Input parameters: hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, and cache_position.
    • Warning handling: if output_attentions is True, a warning is issued and the parent class's forward method is called instead.
    • Query/key/value states: computed via the linear projections and reshaped to match the multi-head attention layout.
    • Rotary position embedding: injects position information into the query and key states.
    • Cache update: if a past key-value cache exists, it is updated.
    • KV repetition: keys and values are repeated to match the number of attention heads.
    • Causal mask: the attention mask, if present, is sliced to the key length.
    • Contiguity fix: on CUDA with a custom attention mask, the query, key, and value tensors are made contiguous.
    • is_causal decision: determines whether SDPA should build its own causal mask.
    • Attention output: computed with torch.nn.functional.scaled_dot_product_attention.
    • Output reshaping: the result is reshaped back to match the input shape.
    • Return values: the attention output, None (no attention weights), and the past key-value cache.

Relation to the Qwen2 report

The Qwen2 series, including Qwen2-72B, shows strong performance across many benchmarks. The attention mechanism is a key component of these models, and the Qwen2MoeSdpaAttention class is one of its core implementations. The report attributes part of Qwen2's strength in language understanding, generation, and multilinguality to efficient attention implementations like this one.
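
For reference, a minimal standalone call to PyTorch's scaled_dot_product_attention (toy shapes, mirroring the layout used above):

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 4, 6, 32
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# is_causal=True lets SDPA build the causal mask itself (the q_len > 1 branch above)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)                                   # torch.Size([1, 4, 6, 32])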

QWEN2MOE_ATTENTION_CLASSES = {
    "eager": Qwen2MoeAttention,
    "flash_attention_2": Qwen2MoeFlashAttention2,
    "sdpa": Qwen2MoeSdpaAttention,
}


class Qwen2MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

Here is a line-by-line explanation of this code:

QWEN2MOE_ATTENTION_CLASSES = {
    "eager": Qwen2MoeAttention,
    "flash_attention_2": Qwen2MoeFlashAttention2,
    "sdpa": Qwen2MoeSdpaAttention,
}

This dictionary maps the attention implementations available to the Qwen2 MoE model: standard ("eager") attention, Flash Attention 2, and SDPA (PyTorch's torch.nn.functional.scaled_dot_product_attention).

class Qwen2MoeSparseMoeBlock(nn.Module):

This class defines the sparse Mixture-of-Experts (Sparse MoE) block of the Qwen2 MoE model.

def __init__(self, config):
    super().__init__()
    self.num_experts = config.num_experts
    self.top_k = config.num_experts_per_tok
    self.norm_topk_prob = config.norm_topk_prob

The initializer sets the number of experts, the number of experts used per token (top_k), and whether to renormalize the top-k probabilities.

self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

This defines the routing gate, which decides which experts each token is sent to.

self.experts = nn.ModuleList(
    [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
)

This creates the list of experts, each of which is a Qwen2MoeMLP using the MoE intermediate size.

self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

This defines a shared expert and its gate; the shared expert is applied to every token, with its contribution scaled by the gate.

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

The forward pass, which processes the input hidden states.

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

Flatten the input to (batch_size * sequence_length, hidden_dim) for easier processing.

router_logits = self.gate(hidden_states)

Compute the router logits.

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

Compute the routing weights with a softmax and select the top-k experts for each token.
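
A tiny standalone example of this routing step (made-up numbers: one token, 4 experts, top_k=2):

import torch
import torch.nn.functional as F

router_logits = torch.tensor([[1.0, 0.2, -0.5, 0.8]])                 # one token, 4 experts
weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weights, selected_experts = torch.topk(weights, k=2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)   # renormalize (norm_topk_prob=True)
print(selected_experts)   # tensor([[0, 3]]) -> experts 0 and 3 handle this token
print(topk_weights)       # their mixing weights, summing to 1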

if self.norm_topk_prob:
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

If configured, renormalize the top-k probabilities so they sum to 1, and cast back to the input dtype.

final_hidden_states = torch.zeros(
    (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

Initialize the final output tensor with zeros.

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

Create a one-hot expert mask, used below to find which tokens were routed to each expert.

This code implements the sparse Mixture-of-Experts mechanism of the Qwen2 MoE model, including expert routing and selection. It is the key ingredient of conditional computation in large language models, letting the model dynamically choose different expert sub-networks depending on the input.
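
To make the mask shapes concrete (hypothetical sizes: 3 tokens, 4 experts, top_k=2):

import torch

selected_experts = torch.tensor([[0, 3], [1, 0], [3, 2]])              # (num_tokens, top_k)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)
print(expert_mask.shape)          # torch.Size([4, 2, 3]) -> (num_experts, top_k, num_tokens)
idx, top_x = torch.where(expert_mask[0])
print(top_x)                      # tensor([0, 1]) -> tokens 0 and 1 were routed to expert 0
print(idx)                        # tensor([0, 1]) -> at top-k slots 0 and 1 respectively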

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        shared_expert_output = self.shared_expert(hidden_states)
        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

        final_hidden_states = final_hidden_states + shared_expert_output

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

Continuing the line-by-line explanation of this code:

for expert_idx in range(self.num_experts):

This loop iterates over all experts.

expert_layer = self.experts[expert_idx]

Get the current expert's layer.

idx, top_x = torch.where(expert_mask[expert_idx])

Find the indices of all tokens routed to the current expert (top_x) and the top-k slot they occupy (idx).

current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

Gather the hidden states of the tokens routed to the current expert.

current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

Run the selected hidden states through the current expert and scale the output by the corresponding routing weights.

final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

Add the current expert's output into the final hidden states at the correct token positions using index_add_.
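
index_add_ scatters rows back by index; a tiny illustration with made-up values:

import torch

final = torch.zeros(4, 2)                      # 4 tokens, hidden_dim = 2
top_x = torch.tensor([1, 3])                   # this expert handled tokens 1 and 3
expert_out = torch.tensor([[1.0, 1.0],
                           [2.0, 2.0]])
final.index_add_(0, top_x, expert_out)         # rows 1 and 3 receive the expert outputs
print(final)
# tensor([[0., 0.],
#         [1., 1.],
#         [0., 0.],
#         [2., 2.]])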

shared_expert_output = self.shared_expert(hidden_states)

Compute the shared expert's output.

shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

Apply the shared expert's gate: a sigmoid over a scalar logit scales the shared expert's output.

final_hidden_states = final_hidden_states + shared_expert_output

Add the shared expert's output to the final hidden states.

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

Reshape the final hidden states back to the original batch size and sequence length.

return final_hidden_states, router_logits

Return the final hidden states and the router logits.

This code implements the core Sparse MoE logic: it loops over every expert, feeds each expert the tokens routed to it, and merges the weighted expert outputs. It also includes a shared expert, likely intended to capture features common to all inputs. This approach lets the model dynamically choose different computation paths per token, potentially improving both capability and efficiency.
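
A rough usage sketch of the block, assuming the classes defined above (Qwen2MoeMLP, Qwen2MoeSparseMoeBlock) and their imports (torch, torch.nn.functional as F, ACT2FN) are available; the config below is a hypothetical toy stand-in for Qwen2MoeConfig, not real Qwen2 settings:

import torch
from types import SimpleNamespace

config = SimpleNamespace(                        # toy subset of Qwen2MoeConfig fields
    hidden_size=32,
    num_experts=4,
    num_experts_per_tok=2,
    norm_topk_prob=True,
    moe_intermediate_size=64,
    shared_expert_intermediate_size=64,
    hidden_act="silu",
)

block = Qwen2MoeSparseMoeBlock(config)
x = torch.randn(2, 7, 32)                        # (batch, seq_len, hidden)
out, router_logits = block(x)
print(out.shape, router_logits.shape)            # torch.Size([2, 7, 32]) torch.Size([14, 4])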

class Qwen2MoeDecoderLayer(nn.Module):
    def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

Continuing the line-by-line explanation:

class Qwen2MoeDecoderLayer(nn.Module):

This defines the decoder layer class of the Qwen2 MoE model.

def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
    super().__init__()
    self.hidden_size = config.hidden_size

The initializer takes the configuration and the layer index as arguments and stores the hidden size.

self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

Initialize the self-attention layer according to the attention implementation specified in the config (eager, flash_attention_2, or sdpa).

if (layer_idx not in config.mlp_only_layers) and (
    config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
    self.mlp = Qwen2MoeSparseMoeBlock(config)
else:
    self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)

This code decides whether the layer uses a sparse MoE block or a standard MLP. The MoE block is used when all of the following hold (a small sketch follows the list):

  1. the current layer is not in the MLP-only layer list (config.mlp_only_layers)
  2. the number of experts is greater than 0
  3. the layer index satisfies the sparse step, i.e. every config.decoder_sparse_step-th layer
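
For intuition, here is how that condition plays out for a hypothetical config (illustrative numbers only, not Qwen2's actual values):

mlp_only_layers = [0]
decoder_sparse_step = 2
num_experts = 4

for layer_idx in range(8):
    use_moe = (layer_idx not in mlp_only_layers) and (
        num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0
    )
    print(layer_idx, "MoE" if use_moe else "dense MLP")
# layers 1, 3, 5, 7 get the sparse MoE block; the rest use the dense Qwen2MoeMLP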

self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

Initialize the input layer norm and the post-attention layer norm (both RMSNorm).

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    output_router_logits: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

This is the definition of the forward pass, which takes several optional parameters:

  • hidden_states: the input hidden states
  • attention_mask: the attention mask
  • position_ids: the position ids
  • past_key_value: cached key-value pairs from previous steps
  • output_attentions: whether to return attention weights
  • output_router_logits: whether to return the router logits
  • use_cache: whether to use the KV cache
  • cache_position: the cache positions

The body of this function, with the actual forward logic, follows in the next part of the code.

This class implements one decoder layer of the Qwen2 MoE model, containing self-attention plus either a dense MLP or a sparse MoE block depending on the layer. Allowing different layers to use different computation structures is a common optimization in large language models.

This code defines a class ClassInstantier to manage the instantiation of activation functions, and provides a dictionary ACT2CLS mapping activation-function names to their classes, or to a class together with its initialization arguments. It also defines a get_activation function that returns an activation instance by name. Finally, at the end of the block, several activation functions are instantiated by calling get_activation.

Here is a line-by-line explanation:

class ClassInstantier(OrderedDict):

Define a class named ClassInstantier that inherits from OrderedDict.

    def __getitem__(self, key):

Override the __getitem__ method so that custom handling happens when an item is accessed by key.

        content = super().__getitem__(key)

Fetch the value content associated with key from the parent OrderedDict.

        cls, kwargs = content if isinstance(content, tuple) else (content, {})

Check whether content is a tuple. If so, unpack it into cls and kwargs; otherwise assign content to cls and set kwargs to an empty dict.

        return cls(**kwargs)

Create an instance from cls and kwargs and return it.
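
The effect is that indexing the mapping returns a fresh instance, with any stored kwargs applied. A small self-contained demo (reusing the same class definition, with ordinary torch.nn activations standing in for the transformers ones):

from collections import OrderedDict
import torch.nn as nn

class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)

demo = ClassInstantier({"relu": nn.ReLU, "leaky": (nn.LeakyReLU, {"negative_slope": 0.2})})
act = demo["leaky"]               # a new nn.LeakyReLU(negative_slope=0.2) instance on each lookup
print(act)                        # LeakyReLU(negative_slope=0.2)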

ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}

Define a dictionary ACT2CLS whose keys are activation-function names and whose values are either the corresponding class or a tuple of the class and its initialization arguments.

ACT2FN = ClassInstantier(ACT2CLS)

Wrap ACT2CLS in a ClassInstantier to produce ACT2FN.

def get_activation(activation_string):

Define a function get_activation for retrieving an activation-function instance.

    if activation_string in ACT2FN:
        return ACT2FN[activation_string]

If activation_string is in ACT2FN, return the corresponding activation instance.

    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")

Otherwise, raise a KeyError stating that the activation function was not found.

# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")

For backwards compatibility, several activation functions are instantiated via get_activation and bound to module-level names.
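
This is exactly the mechanism Qwen2MoeMLP relies on when it does self.act_fn = ACT2FN[config.hidden_act]. A quick check, assuming the transformers library is installed:

from transformers.activations import ACT2FN, get_activation
import torch

silu = ACT2FN["silu"]                      # same as get_activation("silu")
x = torch.tensor([-1.0, 0.0, 1.0])
print(silu(x))                             # SiLU applied element-wise
print(type(get_activation("gelu_10")))     # ClippedGELUActivation with min=-10, max=10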
