ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
259 lines
9.6 KiB
Python
Executable File
259 lines
9.6 KiB
Python
Executable File
import torch
|
||
from torch import nn
|
||
import torch.distributed as dist
|
||
from transformers import Qwen3Config
|
||
|
||
from nanovllm.layers.activation import SiluAndMul
|
||
from nanovllm.layers.attention import Attention
|
||
from nanovllm.layers.layernorm import RMSNorm
|
||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||
from nanovllm.layers.rotary_embedding import get_rope
|
||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||
|
||
|
||
class Qwen3Attention(nn.Module):
|
||
"""Qwen3 的注意力层。
|
||
|
||
支持 GQA(Grouped Query Attention):num_kv_heads 可以小于 num_heads,
|
||
多个 query head 共享同一组 KV head,减少 KV cache 的显存占用。
|
||
|
||
当模型配置中没有 qkv_bias 时(Qwen3 默认无 bias),会在 Q 和 K 投影后
|
||
添加 RMSNorm(q_norm, k_norm),这是 Qwen3 的特殊设计。
|
||
|
||
支持张量并行(TP):QKV 投影按列切分,O 投影按行切分。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
hidden_size: int,
|
||
num_heads: int,
|
||
num_kv_heads: int,
|
||
max_position: int = 4096 * 32,
|
||
head_dim: int | None = None,
|
||
rms_norm_eps: float = 1e-06,
|
||
qkv_bias: bool = False,
|
||
rope_theta: float = 10000,
|
||
rope_scaling: dict | None = None,
|
||
) -> None:
|
||
super().__init__()
|
||
tp_size = dist.get_world_size()
|
||
self.total_num_heads = num_heads
|
||
assert self.total_num_heads % tp_size == 0
|
||
self.num_heads = self.total_num_heads // tp_size # 当前 TP rank 拥有的 query head 数
|
||
self.total_num_kv_heads = num_kv_heads
|
||
assert self.total_num_kv_heads % tp_size == 0
|
||
self.num_kv_heads = self.total_num_kv_heads // tp_size # 当前 TP rank 拥有的 KV head 数
|
||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||
self.q_size = self.num_heads * self.head_dim
|
||
self.kv_size = self.num_kv_heads * self.head_dim
|
||
self.scaling = self.head_dim ** -0.5 # 注意力缩放因子: 1/sqrt(d_k)
|
||
self.qkv_bias = qkv_bias
|
||
|
||
# QKV 融合投影:将 hidden_states 投影为 Q、K、V,按 TP 切分
|
||
self.qkv_proj = QKVParallelLinear(
|
||
hidden_size,
|
||
self.head_dim,
|
||
self.total_num_heads,
|
||
self.total_num_kv_heads,
|
||
bias=qkv_bias,
|
||
)
|
||
# 输出投影:将注意力输出映射回 hidden_size,按 TP 行切分
|
||
self.o_proj = RowParallelLinear(
|
||
self.total_num_heads * self.head_dim,
|
||
hidden_size,
|
||
bias=False,
|
||
)
|
||
if isinstance(rope_scaling, dict):
|
||
rope_theta = rope_scaling.get("rope_theta", rope_theta)
|
||
# 旋转位置编码(RoPE)
|
||
self.rotary_emb = get_rope(
|
||
self.head_dim,
|
||
rotary_dim=self.head_dim,
|
||
max_position=max_position,
|
||
base=rope_theta,
|
||
)
|
||
self.attn = Attention(
|
||
self.num_heads,
|
||
self.head_dim,
|
||
self.scaling,
|
||
self.num_kv_heads,
|
||
)
|
||
# Qwen3 无 bias 时使用 Q/K norm(post-normalization)
|
||
if not self.qkv_bias:
|
||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||
|
||
def forward(
|
||
self,
|
||
positions: torch.Tensor,
|
||
hidden_states: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
qkv = self.qkv_proj(hidden_states)
|
||
# 将 QKV 融合结果拆分为 Q、K、V
|
||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||
q = q.view(-1, self.num_heads, self.head_dim)
|
||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||
if not self.qkv_bias:
|
||
# Qwen3 特有:对 Q 和 K 做 RMSNorm
|
||
q = self.q_norm(q)
|
||
k = self.k_norm(k)
|
||
# 应用旋转位置编码
|
||
q, k = self.rotary_emb(positions, q, k)
|
||
# 注意力计算(含 KV cache 读写)
|
||
o = self.attn(q, k, v)
|
||
# 输出投影
|
||
output = self.o_proj(o.flatten(1, -1))
|
||
return output
|
||
|
||
|
||
class Qwen3MLP(nn.Module):
|
||
"""Qwen3 的前馈网络(MLP),使用 SwiGLU 激活函数。
|
||
|
||
结构: hidden_size → 2×intermediate_size (gate + up) → intermediate_size → hidden_size
|
||
其中 gate 和 up 投影融合为一个矩阵乘法,然后 SiLU(gate) * up。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
hidden_size: int,
|
||
intermediate_size: int,
|
||
hidden_act: str,
|
||
) -> None:
|
||
super().__init__()
|
||
# 融合 gate_proj 和 up_proj 为一个矩阵,减少一次 kernel launch
|
||
self.gate_up_proj = MergedColumnParallelLinear(
|
||
hidden_size,
|
||
[intermediate_size] * 2, # gate 和 up 各输出 intermediate_size
|
||
bias=False,
|
||
)
|
||
self.down_proj = RowParallelLinear(
|
||
intermediate_size,
|
||
hidden_size,
|
||
bias=False,
|
||
)
|
||
assert hidden_act == "silu"
|
||
self.act_fn = SiluAndMul() # SiLU(gate) * up
|
||
|
||
def forward(self, x):
|
||
gate_up = self.gate_up_proj(x)
|
||
x = self.act_fn(gate_up)
|
||
x = self.down_proj(x)
|
||
return x
|
||
|
||
|
||
class Qwen3DecoderLayer(nn.Module):
|
||
"""Qwen3 的单个 Transformer 解码层。
|
||
|
||
结构: Input RMSNorm → Self-Attention → Residual → Post-Attention RMSNorm → MLP → Residual
|
||
使用 Pre-Norm 架构(先归一化再进入子层),并将残差连接的计算融合到 RMSNorm 中以节省显存。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
config: Qwen3Config,
|
||
) -> None:
|
||
super().__init__()
|
||
self.self_attn = Qwen3Attention(
|
||
hidden_size=config.hidden_size,
|
||
num_heads=config.num_attention_heads,
|
||
num_kv_heads=config.num_key_value_heads,
|
||
max_position=config.max_position_embeddings,
|
||
rms_norm_eps=config.rms_norm_eps,
|
||
qkv_bias=getattr(config, 'attention_bias', True),
|
||
head_dim=getattr(config, 'head_dim', None),
|
||
rope_theta=getattr(config, "rope_theta", 1000000),
|
||
rope_scaling=getattr(config, "rope_scaling", None),
|
||
)
|
||
self.mlp = Qwen3MLP(
|
||
hidden_size=config.hidden_size,
|
||
intermediate_size=config.intermediate_size,
|
||
hidden_act=config.hidden_act,
|
||
)
|
||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||
|
||
def forward(
|
||
self,
|
||
positions: torch.Tensor,
|
||
hidden_states: torch.Tensor,
|
||
residual: torch.Tensor | None,
|
||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
# 残差连接融合到 RMSNorm 中:residual = hidden_states + residual, output = RMSNorm(residual)
|
||
if residual is None:
|
||
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
|
||
else:
|
||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||
hidden_states = self.self_attn(positions, hidden_states)
|
||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||
hidden_states = self.mlp(hidden_states)
|
||
return hidden_states, residual
|
||
|
||
|
||
class Qwen3Model(nn.Module):
|
||
"""Qwen3 的 Transformer 主体:Embedding → N × DecoderLayer → Final RMSNorm。"""
|
||
|
||
def __init__(
|
||
self,
|
||
config: Qwen3Config,
|
||
) -> None:
|
||
super().__init__()
|
||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
||
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||
|
||
def forward(
|
||
self,
|
||
input_ids: torch.Tensor,
|
||
positions: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
hidden_states = self.embed_tokens(input_ids)
|
||
residual = None
|
||
for layer in self.layers:
|
||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||
# 最终的 RMSNorm 也融合了残差加法
|
||
hidden_states, _ = self.norm(hidden_states, residual)
|
||
return hidden_states
|
||
|
||
|
||
class Qwen3ForCausalLM(nn.Module):
|
||
"""Qwen3 因果语言模型:用于文本生成。
|
||
|
||
包含映射表,将 HuggingFace 的独立权重名(q_proj, k_proj, v_proj, gate_proj, up_proj)
|
||
映射到本项目融合后的权重名(qkv_proj, gate_up_proj),以便正确加载权重。
|
||
"""
|
||
|
||
# 融合模块的映射关系:HF 权重名 → (本项目模块名, shard_id)
|
||
packed_modules_mapping = {
|
||
"q_proj": ("qkv_proj", "q"),
|
||
"k_proj": ("qkv_proj", "k"),
|
||
"v_proj": ("qkv_proj", "v"),
|
||
"gate_proj": ("gate_up_proj", 0),
|
||
"up_proj": ("gate_up_proj", 1),
|
||
}
|
||
|
||
def __init__(
|
||
self,
|
||
config: Qwen3Config
|
||
) -> None:
|
||
super().__init__()
|
||
self.model = Qwen3Model(config)
|
||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||
# 如果模型配置了权重共享(tie),LM Head 和 Embedding 使用同一个权重矩阵
|
||
if config.tie_word_embeddings:
|
||
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||
|
||
def forward(
|
||
self,
|
||
input_ids: torch.Tensor,
|
||
positions: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
return self.model(input_ids, positions)
|
||
|
||
def compute_logits(
|
||
self,
|
||
hidden_states: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
"""将最后一层隐藏状态通过 LM Head 转换为词表 logits。"""
|
||
return self.lm_head(hidden_states)
|