import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen3Config

from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead


class Qwen3Attention(nn.Module):
    """Self-attention sub-layer of a Qwen3 decoder layer.

    The query-head count and the KV-head count are configured separately, so
    ``num_kv_heads`` may be smaller than ``num_heads`` (grouped-query
    attention). Both counts are divided by the distributed world size to get
    the per-rank head counts, and both must be divisible by it.

    When ``qkv_bias`` is false, an ``RMSNorm`` over ``head_dim`` is applied to
    Q and to K right after the projection (``q_norm`` / ``k_norm``). When
    ``qkv_bias`` is true those two norms are neither created nor applied.

    Args:
        hidden_size: Width of the incoming hidden states.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
        max_position: Maximum position handed to the rotary embedding.
        head_dim: Per-head width; when falsy it is derived as
            ``hidden_size // num_heads``.
        rms_norm_eps: Epsilon for the Q/K ``RMSNorm`` layers.
        qkv_bias: Whether the fused QKV projection carries a bias.
        rope_theta: Base passed to the rotary embedding.
        rope_scaling: Optional dict; if it has a ``"rope_theta"`` key, that
            value overrides ``rope_theta``. No other key is read here.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: dict | None = None,
    ) -> None:
        super().__init__()
        # Requires an initialised default process group.
        world_size = dist.get_world_size()

        # Query heads: total count and the share owned by this rank.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # Key/value heads: total count and the share owned by this rank.
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        # An explicit (truthy) head_dim wins; otherwise derive it.
        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Attention scale factor 1 / sqrt(head_dim).
        self.scaling = self.head_dim ** -0.5
        self.qkv_bias = qkv_bias

        # Fused projection producing Q, K and V in a single matmul.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        # Output projection back to hidden_size.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )

        # A rope_scaling dict may carry its own rope_theta, which takes priority.
        rope_base = (
            rope_scaling.get("rope_theta", rope_theta)
            if isinstance(rope_scaling, dict)
            else rope_theta
        )
        # Rotary position embedding over the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_base,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

        # Q/K normalisation exists only in the bias-free configuration.
        if not self.qkv_bias:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project, optionally normalise, rotate, attend, and project back.

        Args:
            positions: Positions forwarded to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The result of ``o_proj`` applied to the flattened attention output.
        """
        fused = self.qkv_proj(hidden_states)

        # Slice the fused output along the last dim, then expose the head axis.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        query = query.view(-1, self.num_heads, self.head_dim)
        key = key.view(-1, self.num_kv_heads, self.head_dim)
        value = value.view(-1, self.num_kv_heads, self.head_dim)

        # Per-head RMSNorm on Q and K (bias-free configuration only).
        if not self.qkv_bias:
            query = self.q_norm(query)
            key = self.k_norm(key)

        query, key = self.rotary_emb(positions, query, key)

        # NOTE(review): KV-cache reads/writes presumably happen inside
        # Attention; that is not visible from this file.
        attn_out = self.attn(query, key, value)

        # Merge the (heads, head_dim) axes before the output projection.
        return self.o_proj(attn_out.flatten(1, -1))


class Qwen3MLP(nn.Module):
    """Feed-forward sub-layer of a Qwen3 decoder layer.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``
    each. ``SiluAndMul`` then combines them, and ``down_proj`` maps the result
    back to ``hidden_size``. Only ``hidden_act == "silu"`` is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        # One fused matmul for gate and up, each of width intermediate_size.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        # Presumably computes SiLU(gate) * up on the fused output -- confirm
        # against nanovllm.layers.activation.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused gate/up projection, the activation, then down_proj."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 decoder layer: attention then MLP, each preceded by a norm.

    Call order in ``forward``: ``input_layernorm`` -> ``self_attn`` ->
    ``post_attention_layernorm`` -> ``mlp``. The residual stream is threaded
    through the ``RMSNorm`` calls instead of being added explicitly here.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        # NOTE(review): the fallback for a missing `attention_bias` is True,
        # whereas Qwen3Attention itself defaults qkv_bias to False. With the
        # fallback, Q/K norm is skipped. Confirm this is intentional for
        # configs that lack the attribute.
        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            positions: Positions forwarded to the attention sub-layer.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual stream from the previous layer, or ``None`` for
                the first layer.

        Returns:
            ``(hidden_states, residual)``. The MLP output is returned without
            the residual added; the caller's next norm call receives both.
        """
        # Called with a residual, RMSNorm returns (normed, updated_residual).
        # Per the original author's note, it adds the residual and then
        # normalises the sum -- confirm against nanovllm.layers.layernorm.
        if residual is None:
            # First layer: the un-normalised input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen3Model(nn.Module):
    """Qwen3 backbone: token embedding, N decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        decoder_layers = [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Returns:
            The first element returned by the final norm call.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        # The final norm also receives the residual; its second output is unused.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Qwen3ForCausalLM(nn.Module):
    """Qwen3 backbone plus LM head for causal language modelling.

    ``packed_modules_mapping`` relates the separate checkpoint weight names
    (``q_proj``, ``k_proj``, ``v_proj``, ``gate_proj``, ``up_proj``) to the
    fused modules defined in this file (``qkv_proj``, ``gate_up_proj``) and a
    shard id. It is presumably consumed by the weight loader, which is not
    part of this file.
    """

    # checkpoint weight name -> (fused module name in this file, shard id)
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        # Tied embeddings: point the LM head weight at the embedding's storage.
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's hidden states; logits come from ``compute_logits``."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Pass hidden states through the LM head and return its output."""
        return self.lm_head(hidden_states)