Add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,12 @@ from nanovllm.utils.context import get_context
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Module):
|
||||
"""词表并行 Embedding:将词表按 TP rank 切分。
|
||||
|
||||
每个 rank 只存储词表中属于自己的分片。前向计算时:
|
||||
1. 只查找属于自己的 token ID,其他位置输出零。
|
||||
2. 通过 all-reduce 聚合所有 rank 的结果。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -25,6 +31,7 @@ class VocabParallelEmbedding(nn.Module):
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载属于当前 rank 的词表分片。"""
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(0)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
@@ -33,16 +40,27 @@ class VocabParallelEmbedding(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.tp_size > 1:
|
||||
# 构造 mask:标记哪些 token ID 属于当前 rank 的范围
|
||||
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
||||
# 将全局 token ID 转为局部索引
|
||||
x = mask * (x - self.vocab_start_idx)
|
||||
y = F.embedding(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
# 非当前 rank 范围的 token 输出清零,然后 all-reduce 求和
|
||||
y = mask.unsqueeze(1) * y
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
"""并行 LM Head:将隐藏状态映射为词表 logits。
|
||||
|
||||
与 Embedding 共享权重(如果模型配置了 tie_word_embeddings),
|
||||
但前向逻辑不同:使用 F.linear(矩阵乘法)而非 F.embedding(查表)。
|
||||
|
||||
Prefill 阶段只需要每个序列最后一个 token 的 logits(因为只有最后一个 token 会用于采样),
|
||||
所以先用 cu_seqlens 提取最后位置,再做矩阵乘法,减少计算量。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,10 +74,13 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
def forward(self, x: torch.Tensor):
|
||||
context = get_context()
|
||||
if context.is_prefill:
|
||||
# Prefill: 只取每个序列最后一个 token 的隐藏状态用于采样
|
||||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
# 矩阵乘法得到 logits
|
||||
logits = F.linear(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
# Gather 所有 rank 的 logits 分片到 rank 0
|
||||
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
||||
dist.gather(logits, all_logits, 0)
|
||||
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
||||
|
||||
Reference in New Issue
Block a user