add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,12 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""SwiGLU 激活函数:SiLU(gate) * up。
|
||||
|
||||
输入是 gate 和 up 拼接的张量,沿最后一维一分为二,
|
||||
对前半部分应用 SiLU 激活后与后半部分逐元素相乘。
|
||||
这是 LLaMA/Qwen 系列模型中 MLP 层的标准激活函数。
|
||||
"""
|
||||
|
||||
@torch.compile
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -18,6 +18,12 @@ def store_kvcache_kernel(
|
||||
slot_mapping_ptr,
|
||||
D: tl.constexpr,
|
||||
):
|
||||
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
|
||||
|
||||
每个 Triton program 处理一个 token 的 K/V 写入。
|
||||
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
|
||||
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
|
||||
"""
|
||||
idx = tl.program_id(0)
|
||||
slot = tl.load(slot_mapping_ptr + idx)
|
||||
if slot == -1: return
|
||||
@@ -31,6 +37,15 @@ def store_kvcache_kernel(
|
||||
|
||||
|
||||
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
||||
"""将新计算的 K/V 向量存储到 KV cache 中。
|
||||
|
||||
Args:
|
||||
key: [N, num_heads, head_dim] 新计算的 K 向量。
|
||||
value: [N, num_heads, head_dim] 新计算的 V 向量。
|
||||
k_cache: KV cache 中 K 的存储区域。
|
||||
v_cache: KV cache 中 V 的存储区域。
|
||||
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
|
||||
"""
|
||||
N, num_heads, head_dim = key.shape
|
||||
D = num_heads * head_dim
|
||||
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
||||
@@ -41,6 +56,18 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor,
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
|
||||
|
||||
支持 two 阶段的注意力计算:
|
||||
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
|
||||
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
|
||||
|
||||
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
|
||||
已缓存的 K/V,而不是用当前计算的 K/V。
|
||||
|
||||
Attributes:
|
||||
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -54,22 +81,24 @@ class Attention(nn.Module):
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.k_cache = self.v_cache = torch.tensor([])
|
||||
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
context = get_context()
|
||||
k_cache, v_cache = self.k_cache, self.v_cache
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
# 将新计算的 K/V 写入 KV cache
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
if context.is_prefill:
|
||||
if context.block_tables is not None: # prefix cache
|
||||
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
|
||||
k, v = k_cache, v_cache
|
||||
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else: # decode
|
||||
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
return o
|
||||
|
||||
@@ -7,6 +7,12 @@ from nanovllm.utils.context import get_context
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Module):
|
||||
"""词表并行 Embedding:将词表按 TP rank 切分。
|
||||
|
||||
每个 rank 只存储词表中属于自己的分片。前向计算时:
|
||||
1. 只查找属于自己的 token ID,其他位置输出零。
|
||||
2. 通过 all-reduce 聚合所有 rank 的结果。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -25,6 +31,7 @@ class VocabParallelEmbedding(nn.Module):
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载属于当前 rank 的词表分片。"""
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(0)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
@@ -33,16 +40,27 @@ class VocabParallelEmbedding(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.tp_size > 1:
|
||||
# 构造 mask:标记哪些 token ID 属于当前 rank 的范围
|
||||
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
||||
# 将全局 token ID 转为局部索引
|
||||
x = mask * (x - self.vocab_start_idx)
|
||||
y = F.embedding(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
# 非当前 rank 范围的 token 输出清零,然后 all-reduce 求和
|
||||
y = mask.unsqueeze(1) * y
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
"""并行 LM Head:将隐藏状态映射为词表 logits。
|
||||
|
||||
与 Embedding 共享权重(如果模型配置了 tie_word_embeddings),
|
||||
但前向逻辑不同:使用 F.linear(矩阵乘法)而非 F.embedding(查表)。
|
||||
|
||||
Prefill 阶段只需要每个序列最后一个 token 的 logits(因为只有最后一个 token 会用于采样),
|
||||
所以先用 cu_seqlens 提取最后位置,再做矩阵乘法,减少计算量。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,10 +74,13 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
def forward(self, x: torch.Tensor):
|
||||
context = get_context()
|
||||
if context.is_prefill:
|
||||
# Prefill: 只取每个序列最后一个 token 的隐藏状态用于采样
|
||||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
# 矩阵乘法得到 logits
|
||||
logits = F.linear(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
# Gather 所有 rank 的 logits 分片到 rank 0
|
||||
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
||||
dist.gather(logits, all_logits, 0)
|
||||
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
||||
|
||||
@@ -3,6 +3,16 @@ from torch import nn
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""Root Mean Square Layer Normalization(RMSNorm)。
|
||||
|
||||
与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。
|
||||
公式: output = x / sqrt(mean(x^2) + eps) * weight
|
||||
|
||||
提供两个前向路径:
|
||||
- rms_forward: 标准 RMSNorm。
|
||||
- add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm),
|
||||
减少一次显存读写,是 vLLM 等推理框架的常见优化。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -31,6 +41,11 @@ class RMSNorm(nn.Module):
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""融合残差加法和 RMSNorm。
|
||||
|
||||
Returns:
|
||||
(normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。
|
||||
"""
|
||||
orig_dtype = x.dtype
|
||||
x = x.float().add_(residual.float())
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
@@ -5,11 +5,20 @@ import torch.distributed as dist
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""整除断言,确保张量并行时维度能被均匀切分。"""
|
||||
assert numerator % denominator == 0
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
class LinearBase(nn.Module):
|
||||
"""所有并行线性层的基类。
|
||||
|
||||
Attributes:
|
||||
tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。
|
||||
tp_rank: 当前进程在 TP 组中的 rank。
|
||||
tp_size: TP 组的总大小。
|
||||
weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,6 +44,7 @@ class LinearBase(nn.Module):
|
||||
|
||||
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""复制式线性层:所有 TP rank 持有完整的权重副本。用于不需要切分的层。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -52,6 +62,11 @@ class ReplicatedLinear(LinearBase):
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""列并行线性层:将输出维度按 TP rank 切分。
|
||||
|
||||
每个 TP rank 持有输出维度的一个分片。例如输出维度为 4096,TP=2 时每个 rank 持有 2048。
|
||||
常用于 QKV 投影和 FFN 的 gate/up 投影(这些层的输出可以独立计算)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,6 +78,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载权重时按 tp_rank 切取对应的列分片。"""
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
@@ -74,6 +90,13 @@ class ColumnParallelLinear(LinearBase):
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""融合的列并行线性层:将多个线性层合并为一个矩阵乘法。
|
||||
|
||||
典型用途是将 gate_proj 和 up_proj 融合为 gate_up_proj,
|
||||
减少 kernel launch 次数,提升计算效率。
|
||||
|
||||
权重加载时需要根据 shard_id(子层索引)定位到正确的权重分片位置。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -85,7 +108,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
super().__init__(input_size, sum(output_sizes), bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
|
||||
"""根据 shard_id 将权重加载到融合矩阵的正确位置。"""
|
||||
param_data = param.data
|
||||
# 计算该子层在融合矩阵中的偏移量
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
@@ -94,6 +119,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""QKV 融合的列并行线性层。
|
||||
|
||||
将 Q、K、V 三个投影合并为一个矩阵乘法。
|
||||
权重按 [Q | K | V] 的顺序排列,加载时根据 shard_id("q"/"k"/"v")
|
||||
定位到对应的位置。
|
||||
|
||||
支持 GQA:Q 的 head 数和 KV 的 head 数可以不同。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -112,6 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
super().__init__(hidden_size, output_size, bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
|
||||
"""根据 shard_id ("q"/"k"/"v") 将权重加载到融合矩阵的正确位置。"""
|
||||
param_data = param.data
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
if loaded_shard_id == "q":
|
||||
@@ -120,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
elif loaded_shard_id == "k":
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
else:
|
||||
else: # "v"
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
@@ -129,6 +163,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""行并行线性层:将输入维度按 TP rank 切分。
|
||||
|
||||
每个 TP rank 持有输入维度的一个分片。前向计算后需要 all-reduce
|
||||
将所有 rank 的结果求和,得到完整的输出。
|
||||
常用于 O 投影和 FFN 的 down 投影(这些层的输出需要跨 rank 聚合)。
|
||||
|
||||
偏置项只在 rank 0 添加,避免 all-reduce 后重复加 bias。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -140,8 +182,10 @@ class RowParallelLinear(LinearBase):
|
||||
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载权重时按 tp_rank 切取对应的行分片。"""
|
||||
param_data = param.data
|
||||
if param_data.ndim == 1:
|
||||
# bias 不切分,每个 rank 持有完整副本
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
@@ -150,7 +194,8 @@ class RowParallelLinear(LinearBase):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# 只有 rank 0 加 bias,避免 all-reduce 后重复
|
||||
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(y)
|
||||
dist.all_reduce(y) # 跨 rank 求和,得到完整输出
|
||||
return y
|
||||
|
||||
@@ -8,6 +8,14 @@ def apply_rotary_emb(
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""应用旋转位置编码(RoPE)。
|
||||
|
||||
将向量沿最后一维分成两半 (x1, x2),然后做旋转变换:
|
||||
y1 = x1 * cos - x2 * sin
|
||||
y2 = x2 * cos + x1 * sin
|
||||
这等价于在二维平面上将每个相邻的 (x1, x2) 对旋转 theta 角度,
|
||||
其中 theta = position / (base^(2i/d))。
|
||||
"""
|
||||
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
||||
y1 = x1 * cos - x2 * sin
|
||||
y2 = x2 * cos + x1 * sin
|
||||
@@ -15,6 +23,15 @@ def apply_rotary_emb(
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""旋转位置编码(Rotary Position Embedding, RoPE)。
|
||||
|
||||
RoPE 通过旋转矩阵编码位置信息,使得注意力计算中
|
||||
内积只依赖相对位置(q_i · k_j 只与 i-j 有关),
|
||||
从而天然支持外推到更长序列。
|
||||
|
||||
预计算所有位置的 cos 和 sin 值并缓存,避免重复计算。
|
||||
缓存形状: [max_position_embeddings, 1, head_dim],中间维是 num_heads 的广播维度。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -26,11 +43,14 @@ class RotaryEmbedding(nn.Module):
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
assert rotary_dim == head_size
|
||||
# 计算逆频率: 1 / (base^(2i/d)), i = 0, 1, ..., d/2-1
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||
# 计算所有位置的频率: pos * inv_freq
|
||||
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
# 拼接 cos 和 sin,形状 [max_pos, 1, head_dim]
|
||||
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
@@ -41,6 +61,7 @@ class RotaryEmbedding(nn.Module):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""根据位置索引查找 cos/sin 并应用到 Q 和 K。"""
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
query = apply_rotary_emb(query, cos, sin)
|
||||
@@ -55,5 +76,9 @@ def get_rope(
|
||||
max_position: int,
|
||||
base: float,
|
||||
):
|
||||
"""获取 RotaryEmbedding 的单例实例。
|
||||
|
||||
使用 lru_cache 确保相同参数只创建一次。
|
||||
"""
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
return rotary_emb
|
||||
|
||||
@@ -3,10 +3,21 @@ from torch import nn
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
"""采样器:将 logits 转换为 token ID。
|
||||
|
||||
使用 Gumbel-like 采样方法(而非标准的 top-k/top-p):
|
||||
1. 将 logits 除以温度(temperature)。
|
||||
2. 计算 softmax 得到概率分布。
|
||||
3. 用指数分布噪声扰动概率,取 argmax。
|
||||
|
||||
这种方法等价于从 softmax(logits/temperature) 分布中采样,
|
||||
但避免了逐元素随机选择的低效操作,全部用张量运算实现。
|
||||
"""
|
||||
|
||||
@torch.compile
|
||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
||||
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
# 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样
|
||||
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
||||
return sample_tokens
|
||||
|
||||
Reference in New Issue
Block a user