add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -5,11 +5,20 @@ import torch.distributed as dist
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""整除断言,确保张量并行时维度能被均匀切分。"""
|
||||
assert numerator % denominator == 0
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
class LinearBase(nn.Module):
|
||||
"""所有并行线性层的基类。
|
||||
|
||||
Attributes:
|
||||
tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。
|
||||
tp_rank: 当前进程在 TP 组中的 rank。
|
||||
tp_size: TP 组的总大小。
|
||||
weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,6 +44,7 @@ class LinearBase(nn.Module):
|
||||
|
||||
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""复制式线性层:所有 TP rank 持有完整的权重副本。用于不需要切分的层。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -52,6 +62,11 @@ class ReplicatedLinear(LinearBase):
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""列并行线性层:将输出维度按 TP rank 切分。
|
||||
|
||||
每个 TP rank 持有输出维度的一个分片。例如输出维度为 4096,TP=2 时每个 rank 持有 2048。
|
||||
常用于 QKV 投影和 FFN 的 gate/up 投影(这些层的输出可以独立计算)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,6 +78,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载权重时按 tp_rank 切取对应的列分片。"""
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
@@ -74,6 +90,13 @@ class ColumnParallelLinear(LinearBase):
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""融合的列并行线性层:将多个线性层合并为一个矩阵乘法。
|
||||
|
||||
典型用途是将 gate_proj 和 up_proj 融合为 gate_up_proj,
|
||||
减少 kernel launch 次数,提升计算效率。
|
||||
|
||||
权重加载时需要根据 shard_id(子层索引)定位到正确的权重分片位置。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -85,7 +108,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
super().__init__(input_size, sum(output_sizes), bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
|
||||
"""根据 shard_id 将权重加载到融合矩阵的正确位置。"""
|
||||
param_data = param.data
|
||||
# 计算该子层在融合矩阵中的偏移量
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
@@ -94,6 +119,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""QKV 融合的列并行线性层。
|
||||
|
||||
将 Q、K、V 三个投影合并为一个矩阵乘法。
|
||||
权重按 [Q | K | V] 的顺序排列,加载时根据 shard_id("q"/"k"/"v")
|
||||
定位到对应的位置。
|
||||
|
||||
支持 GQA:Q 的 head 数和 KV 的 head 数可以不同。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -112,6 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
super().__init__(hidden_size, output_size, bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
|
||||
"""根据 shard_id ("q"/"k"/"v") 将权重加载到融合矩阵的正确位置。"""
|
||||
param_data = param.data
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
if loaded_shard_id == "q":
|
||||
@@ -120,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
elif loaded_shard_id == "k":
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
else:
|
||||
else: # "v"
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
@@ -129,6 +163,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""行并行线性层:将输入维度按 TP rank 切分。
|
||||
|
||||
每个 TP rank 持有输入维度的一个分片。前向计算后需要 all-reduce
|
||||
将所有 rank 的结果求和,得到完整的输出。
|
||||
常用于 O 投影和 FFN 的 down 投影(这些层的输出需要跨 rank 聚合)。
|
||||
|
||||
偏置项只在 rank 0 添加,避免 all-reduce 后重复加 bias。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -140,8 +182,10 @@ class RowParallelLinear(LinearBase):
|
||||
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""加载权重时按 tp_rank 切取对应的行分片。"""
|
||||
param_data = param.data
|
||||
if param_data.ndim == 1:
|
||||
# bias 不切分,每个 rank 持有完整副本
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
@@ -150,7 +194,8 @@ class RowParallelLinear(LinearBase):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# 只有 rank 0 加 bias,避免 all-reduce 后重复
|
||||
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(y)
|
||||
dist.all_reduce(y) # 跨 rank 求和,得到完整输出
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user