add Chinese annotations to all source files for learning purposes

Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 21:33:15 +08:00
parent bb823b3e06
commit ffd2defdfc
19 changed files with 656 additions and 34 deletions
+6
View File
@@ -4,6 +4,12 @@ import torch.nn.functional as F
class SiluAndMul(nn.Module):
"""SwiGLU 激活函数:SiLU(gate) * up。
输入是 gate 和 up 拼接的张量,沿最后一维一分为二,
对前半部分应用 SiLU 激活后与后半部分逐元素相乘。
这是 LLaMA/Qwen 系列模型中 MLP 层的标准激活函数。
"""
@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
+33 -4
View File
@@ -18,6 +18,12 @@ def store_kvcache_kernel(
slot_mapping_ptr,
D: tl.constexpr,
):
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
每个 Triton program 处理一个 token 的 K/V 写入。
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
"""
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
@@ -31,6 +37,15 @@ def store_kvcache_kernel(
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
"""将新计算的 K/V 向量存储到 KV cache 中。
Args:
key: [N, num_heads, head_dim] 新计算的 K 向量。
value: [N, num_heads, head_dim] 新计算的 V 向量。
k_cache: KV cache 中 K 的存储区域。
v_cache: KV cache 中 V 的存储区域。
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
"""
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
@@ -41,6 +56,18 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor,
class Attention(nn.Module):
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
支持 two 阶段的注意力计算:
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
已缓存的 K/V,而不是用当前计算的 K/V。
Attributes:
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
"""
def __init__(
self,
@@ -54,22 +81,24 @@ class Attention(nn.Module):
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
# 将新计算的 K/V 写入 KV cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # prefix cache
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
k, v = k_cache, v_cache
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o
+21
View File
@@ -7,6 +7,12 @@ from nanovllm.utils.context import get_context
class VocabParallelEmbedding(nn.Module):
"""词表并行 Embedding:将词表按 TP rank 切分。
每个 rank 只存储词表中属于自己的分片。前向计算时:
1. 只查找属于自己的 token ID,其他位置输出零。
2. 通过 all-reduce 聚合所有 rank 的结果。
"""
def __init__(
self,
@@ -25,6 +31,7 @@ class VocabParallelEmbedding(nn.Module):
self.weight.weight_loader = self.weight_loader
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"""加载属于当前 rank 的词表分片。"""
param_data = param.data
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
@@ -33,16 +40,27 @@ class VocabParallelEmbedding(nn.Module):
def forward(self, x: torch.Tensor):
if self.tp_size > 1:
# 构造 mask:标记哪些 token ID 属于当前 rank 的范围
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
# 将全局 token ID 转为局部索引
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
# 非当前 rank 范围的 token 输出清零,然后 all-reduce 求和
y = mask.unsqueeze(1) * y
dist.all_reduce(y)
return y
class ParallelLMHead(VocabParallelEmbedding):
"""并行 LM Head:将隐藏状态映射为词表 logits。
与 Embedding 共享权重(如果模型配置了 tie_word_embeddings),
但前向逻辑不同:使用 F.linear(矩阵乘法)而非 F.embedding(查表)。
Prefill 阶段只需要每个序列最后一个 token 的 logits(因为只有最后一个 token 会用于采样),
所以先用 cu_seqlens 提取最后位置,再做矩阵乘法,减少计算量。
"""
def __init__(
self,
@@ -56,10 +74,13 @@ class ParallelLMHead(VocabParallelEmbedding):
def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
# Prefill: 只取每个序列最后一个 token 的隐藏状态用于采样
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
# 矩阵乘法得到 logits
logits = F.linear(x, self.weight)
if self.tp_size > 1:
# Gather 所有 rank 的 logits 分片到 rank 0
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+15
View File
@@ -3,6 +3,16 @@ from torch import nn
class RMSNorm(nn.Module):
"""Root Mean Square Layer NormalizationRMSNorm)。
与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。
公式: output = x / sqrt(mean(x^2) + eps) * weight
提供两个前向路径:
- rms_forward: 标准 RMSNorm。
- add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm),
减少一次显存读写,是 vLLM 等推理框架的常见优化。
"""
def __init__(
self,
@@ -31,6 +41,11 @@ class RMSNorm(nn.Module):
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""融合残差加法和 RMSNorm。
Returns:
(normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。
"""
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
+47 -2
View File
@@ -5,11 +5,20 @@ import torch.distributed as dist
def divide(numerator, denominator):
"""整除断言,确保张量并行时维度能被均匀切分。"""
assert numerator % denominator == 0
return numerator // denominator
class LinearBase(nn.Module):
"""所有并行线性层的基类。
Attributes:
tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。
tp_rank: 当前进程在 TP 组中的 rank。
tp_size: TP 组的总大小。
weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。
"""
def __init__(
self,
@@ -35,6 +44,7 @@ class LinearBase(nn.Module):
class ReplicatedLinear(LinearBase):
"""复制式线性层:所有 TP rank 持有完整的权重副本。用于不需要切分的层。"""
def __init__(
self,
@@ -52,6 +62,11 @@ class ReplicatedLinear(LinearBase):
class ColumnParallelLinear(LinearBase):
"""列并行线性层:将输出维度按 TP rank 切分。
每个 TP rank 持有输出维度的一个分片。例如输出维度为 4096,TP=2 时每个 rank 持有 2048。
常用于 QKV 投影和 FFN 的 gate/up 投影(这些层的输出可以独立计算)。
"""
def __init__(
self,
@@ -63,6 +78,7 @@ class ColumnParallelLinear(LinearBase):
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"""加载权重时按 tp_rank 切取对应的列分片。"""
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
@@ -74,6 +90,13 @@ class ColumnParallelLinear(LinearBase):
class MergedColumnParallelLinear(ColumnParallelLinear):
"""融合的列并行线性层:将多个线性层合并为一个矩阵乘法。
典型用途是将 gate_proj 和 up_proj 融合为 gate_up_proj
减少 kernel launch 次数,提升计算效率。
权重加载时需要根据 shard_id(子层索引)定位到正确的权重分片位置。
"""
def __init__(
self,
@@ -85,7 +108,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
super().__init__(input_size, sum(output_sizes), bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
"""根据 shard_id 将权重加载到融合矩阵的正确位置。"""
param_data = param.data
# 计算该子层在融合矩阵中的偏移量
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
@@ -94,6 +119,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
class QKVParallelLinear(ColumnParallelLinear):
"""QKV 融合的列并行线性层。
将 Q、K、V 三个投影合并为一个矩阵乘法。
权重按 [Q | K | V] 的顺序排列,加载时根据 shard_id("q"/"k"/"v"
定位到对应的位置。
支持 GQAQ 的 head 数和 KV 的 head 数可以不同。
"""
def __init__(
self,
@@ -112,6 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):
super().__init__(hidden_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
"""根据 shard_id ("q"/"k"/"v") 将权重加载到融合矩阵的正确位置。"""
param_data = param.data
assert loaded_shard_id in ["q", "k", "v"]
if loaded_shard_id == "q":
@@ -120,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
elif loaded_shard_id == "k":
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size
else:
else: # "v"
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
@@ -129,6 +163,14 @@ class QKVParallelLinear(ColumnParallelLinear):
class RowParallelLinear(LinearBase):
"""行并行线性层:将输入维度按 TP rank 切分。
每个 TP rank 持有输入维度的一个分片。前向计算后需要 all-reduce
将所有 rank 的结果求和,得到完整的输出。
常用于 O 投影和 FFN 的 down 投影(这些层的输出需要跨 rank 聚合)。
偏置项只在 rank 0 添加,避免 all-reduce 后重复加 bias。
"""
def __init__(
self,
@@ -140,8 +182,10 @@ class RowParallelLinear(LinearBase):
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"""加载权重时按 tp_rank 切取对应的行分片。"""
param_data = param.data
if param_data.ndim == 1:
# bias 不切分,每个 rank 持有完整副本
param_data.copy_(loaded_weight)
return
shard_size = param_data.size(self.tp_dim)
@@ -150,7 +194,8 @@ class RowParallelLinear(LinearBase):
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 只有 rank 0 加 bias,避免 all-reduce 后重复
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
dist.all_reduce(y) # 跨 rank 求和,得到完整输出
return y
+25
View File
@@ -8,6 +8,14 @@ def apply_rotary_emb(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
"""应用旋转位置编码(RoPE)。
将向量沿最后一维分成两半 (x1, x2),然后做旋转变换:
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
这等价于在二维平面上将每个相邻的 (x1, x2) 对旋转 theta 角度,
其中 theta = position / (base^(2i/d))。
"""
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
@@ -15,6 +23,15 @@ def apply_rotary_emb(
class RotaryEmbedding(nn.Module):
"""旋转位置编码(Rotary Position Embedding, RoPE)。
RoPE 通过旋转矩阵编码位置信息,使得注意力计算中
内积只依赖相对位置(q_i · k_j 只与 i-j 有关),
从而天然支持外推到更长序列。
预计算所有位置的 cos 和 sin 值并缓存,避免重复计算。
缓存形状: [max_position_embeddings, 1, head_dim],中间维是 num_heads 的广播维度。
"""
def __init__(
self,
@@ -26,11 +43,14 @@ class RotaryEmbedding(nn.Module):
super().__init__()
self.head_size = head_size
assert rotary_dim == head_size
# 计算逆频率: 1 / (base^(2i/d)), i = 0, 1, ..., d/2-1
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
# 计算所有位置的频率: pos * inv_freq
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
# 拼接 cos 和 sin,形状 [max_pos, 1, head_dim]
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -41,6 +61,7 @@ class RotaryEmbedding(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""根据位置索引查找 cos/sin 并应用到 Q 和 K。"""
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query = apply_rotary_emb(query, cos, sin)
@@ -55,5 +76,9 @@ def get_rope(
max_position: int,
base: float,
):
"""获取 RotaryEmbedding 的单例实例。
使用 lru_cache 确保相同参数只创建一次。
"""
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb
+11
View File
@@ -3,10 +3,21 @@ from torch import nn
class Sampler(nn.Module):
"""采样器:将 logits 转换为 token ID。
使用 Gumbel-like 采样方法(而非标准的 top-k/top-p):
1. 将 logits 除以温度(temperature)。
2. 计算 softmax 得到概率分布。
3. 用指数分布噪声扰动概率,取 argmax。
这种方法等价于从 softmax(logits/temperature) 分布中采样,
但避免了逐元素随机选择的低效操作,全部用张量运算实现。
"""
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
# 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
return sample_tokens