Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

259 lines
9.6 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen3Config
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
class Qwen3Attention(nn.Module):
"""Qwen3 的注意力层。
支持 GQAGrouped Query Attention):num_kv_heads 可以小于 num_heads
多个 query head 共享同一组 KV head,减少 KV cache 的显存占用。
当模型配置中没有 qkv_bias 时(Qwen3 默认无 bias),会在 Q 和 K 投影后
添加 RMSNormq_norm, k_norm),这是 Qwen3 的特殊设计。
支持张量并行(TP):QKV 投影按列切分,O 投影按行切分。
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
rope_scaling: dict | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size # 当前 TP rank 拥有的 query head 数
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size # 当前 TP rank 拥有的 KV head 数
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5 # 注意力缩放因子: 1/sqrt(d_k)
self.qkv_bias = qkv_bias
# QKV 融合投影:将 hidden_states 投影为 Q、K、V,按 TP 切分
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
)
# 输出投影:将注意力输出映射回 hidden_size,按 TP 行切分
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
if isinstance(rope_scaling, dict):
rope_theta = rope_scaling.get("rope_theta", rope_theta)
# 旋转位置编码(RoPE
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)
# Qwen3 无 bias 时使用 Q/K normpost-normalization
if not self.qkv_bias:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
# 将 QKV 融合结果拆分为 Q、K、V
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
if not self.qkv_bias:
# Qwen3 特有:对 Q 和 K 做 RMSNorm
q = self.q_norm(q)
k = self.k_norm(k)
# 应用旋转位置编码
q, k = self.rotary_emb(positions, q, k)
# 注意力计算(含 KV cache 读写)
o = self.attn(q, k, v)
# 输出投影
output = self.o_proj(o.flatten(1, -1))
return output
class Qwen3MLP(nn.Module):
    """Feed-forward block of a Qwen3 decoder layer.

    The gate and up projections are merged into one linear layer with two
    outputs of ``intermediate_size`` each; ``SiluAndMul`` combines the two
    halves, and ``down_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Create the merged gate/up projection and the down projection.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up outputs.
            hidden_act: Activation name; only ``"silu"`` is accepted.

        Raises:
            AssertionError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # One merged layer yields both the gate and the up activations.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 transformer decoder layer.

    Data flow: input RMSNorm -> self-attention -> post-attention RMSNorm ->
    MLP. The running residual is passed into each ``RMSNorm`` call and handed
    back by it, so this class performs no explicit residual addition.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Instantiate the attention, MLP and norm sub-modules from ``config``."""
        super().__init__()
        norm_eps = config.rms_norm_eps
        width = config.hidden_size

        # NOTE(review): the fallback for a missing ``attention_bias`` is True,
        # while Qwen3Attention itself defaults ``qkv_bias`` to False - confirm
        # this is intended for configs that lack the field.
        self.self_attn = Qwen3Attention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=norm_eps,
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Position tensor forwarded to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None`` for the first layer.

        Returns:
            ``(mlp_output, residual)``; the MLP output has not been added to
            the residual yet - the next norm call receives both.
        """
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            h = self.input_layernorm(hidden_states)
        else:
            h, residual = self.input_layernorm(hidden_states, residual)

        h = self.self_attn(positions, h)
        h, residual = self.post_attention_layernorm(h, residual)
        h = self.mlp(h)
        return h, residual
class Qwen3Model(nn.Module):
    """Qwen3 transformer trunk: token embedding, decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers and the final norm."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(Qwen3DecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids passed to the embedding layer.
            positions: Position tensor passed to each decoder layer.

        Returns:
            The first element of the final norm's result.
        """
        h = self.embed_tokens(input_ids)
        residual = None
        for decoder_layer in self.layers:
            h, residual = decoder_layer(positions, h, residual)
        # The final norm also receives the outstanding residual.
        h, _ = self.norm(h, residual)
        return h
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 causal language model: the transformer trunk plus an LM head.

    ``packed_modules_mapping`` relates the separate checkpoint weight names
    (``q_proj``, ``k_proj``, ``v_proj``, ``gate_proj``, ``up_proj``) to the
    fused modules defined in this file and to the shard each one fills.
    """

    # checkpoint weight name -> (fused module name, shard id)
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        """Create the trunk and LM head, tying weights when the config asks for it."""
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            # Point the LM head at the embedding's weight storage.
            embedding_weight = self.model.embed_tokens.weight
            self.lm_head.weight.data = embedding_weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the trunk's hidden states; logits come from ``compute_logits``."""
        hidden = self.model(input_ids, positions)
        return hidden

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states through the LM head."""
        logits = self.lm_head(hidden_states)
        return logits