Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

202 lines
7.5 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
def divide(numerator: int, denominator: int) -> int:
    """Return ``numerator // denominator``, requiring the division to be exact.

    Used to guarantee that a dimension can be split evenly across
    tensor-parallel ranks.

    Args:
        numerator: Size that is being split.
        denominator: Number of equal parts to split it into.

    Returns:
        The exact integer quotient.

    Raises:
        AssertionError: If ``numerator`` is not a multiple of ``denominator``.
        ZeroDivisionError: If ``denominator`` is zero.
    """
    quotient, remainder = divmod(numerator, denominator)
    if remainder != 0:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type stays AssertionError so existing
        # callers observe the same failure type, now with a useful message.
        raise AssertionError(
            f"{numerator} is not evenly divisible by {denominator}"
        )
    return quotient
class LinearBase(nn.Module):
"""所有并行线性层的基类。
Attributes:
tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。
tp_rank: 当前进程在 TP 组中的 rank。
tp_size: TP 组的总大小。
weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer whose complete weight is held by every TP rank.

    Intended for layers that do not need to be sharded.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        # tp_dim keeps its default of None: nothing is split.
        super().__init__(input_size, output_size, bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any slicing."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this layer's weight and bias."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class ColumnParallelLinear(LinearBase):
    """Column-parallel linear layer: the output dimension is split by rank.

    Each TP rank owns one slice of the output dimension; for example, with an
    output size of 4096 and TP=2 each rank holds 2048 output features.
    Commonly used for QKV projections and the FFN gate/up projections, whose
    outputs can be computed independently per rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        per_rank_output = divide(output_size, dist.get_world_size())
        # tp_dim=0: the weight is stored as (output, input), so dimension 0
        # is the one that gets sharded.
        super().__init__(input_size, per_rank_output, bias, tp_dim=0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The slice is taken along ``tp_dim``; its length is the local
        parameter's size on that dimension and it starts at
        ``tp_rank * length``.
        """
        dim = self.tp_dim
        local = param.data
        width = local.size(dim)
        local.copy_(loaded_weight.narrow(dim, self.tp_rank * width, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with the local weight shard and bias."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Column-parallel layer that fuses several linear layers into one matmul.

    The typical use is fusing ``gate_proj`` and ``up_proj`` into a single
    ``gate_up_proj`` to reduce the number of kernel launches. When loading
    weights, the sub-layer index (``loaded_shard_id``) selects where in the
    fused matrix the incoming tensor is written.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        # Stored before the parent constructor runs; weight_loader reads it.
        self.output_sizes = output_sizes
        total_output = sum(output_sizes)
        super().__init__(input_size, total_output, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        """Write one sub-layer's weights into its region of the fused parameter.

        Args:
            param: Fused local parameter to fill.
            loaded_weight: Full (unsharded) tensor of a single sub-layer.
            loaded_shard_id: Index of that sub-layer within ``output_sizes``.
        """
        world = self.tp_size
        dim = self.tp_dim
        # Local offset of the sub-layer: per-rank share of everything that
        # precedes it in the fused layout.
        preceding = self.output_sizes[:loaded_shard_id]
        offset = sum(preceding) // world
        length = self.output_sizes[loaded_shard_id] // world
        destination = param.data.narrow(dim, offset, length)
        # Split the incoming tensor across ranks and keep this rank's piece.
        rank_slices = loaded_weight.chunk(world, dim)
        destination.copy_(rank_slices[self.tp_rank])
class QKVParallelLinear(ColumnParallelLinear):
    """Column-parallel layer that fuses the Q, K and V projections.

    The three projections are merged into one matmul. The local weight is laid
    out as ``[Q | K | V]``, and ``loaded_shard_id`` ("q"/"k"/"v") selects the
    region to fill when loading. GQA is supported: the number of Q heads and
    the number of KV heads may differ.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        world = dist.get_world_size()
        # A falsy KV head count (None or 0) falls back to the Q head count.
        if not total_num_kv_heads:
            total_num_kv_heads = total_num_heads
        self.head_size = head_size
        # Per-rank head counts; both totals must divide evenly by the world size.
        self.num_heads = divide(total_num_heads, world)
        self.num_kv_heads = divide(total_num_kv_heads, world)
        fused_heads = total_num_heads + 2 * total_num_kv_heads
        super().__init__(hidden_size, fused_heads * self.head_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        """Write the Q, K or V weights into their region of the fused parameter.

        Args:
            param: Fused local parameter to fill.
            loaded_weight: Full (unsharded) tensor of one projection.
            loaded_shard_id: One of "q", "k" or "v".
        """
        assert loaded_shard_id in ["q", "k", "v"]
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size

        is_query = loaded_shard_id == "q"
        length = q_width if is_query else kv_width
        if is_query:
            offset = 0
        elif loaded_shard_id == "k":
            offset = q_width
        else:  # "v"
            offset = q_width + kv_width

        dim = self.tp_dim
        destination = param.data.narrow(dim, offset, length)
        # Split the incoming tensor across ranks and keep this rank's piece.
        rank_slices = loaded_weight.chunk(self.tp_size, dim)
        destination.copy_(rank_slices[self.tp_rank])
class RowParallelLinear(LinearBase):
    """Row-parallel linear layer: the input dimension is split by rank.

    Each TP rank owns one slice of the input dimension. After the local
    matmul, an all-reduce combines the partial results from every rank into
    the complete output. Commonly used for the attention output projection and
    the FFN down projection, whose outputs must be aggregated across ranks.
    The bias is added on rank 0 only so that it is not counted once per rank
    after the all-reduce.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        per_rank_input = divide(input_size, dist.get_world_size())
        # tp_dim=1: the weight is stored as (output, input), so dimension 1
        # is the one that gets sharded.
        super().__init__(per_rank_input, output_size, bias, tp_dim=1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        One-dimensional parameters (the bias) are copied whole, so every rank
        keeps a full copy; anything else is sliced along ``tp_dim``.
        """
        local = param.data
        if local.ndim != 1:
            dim = self.tp_dim
            width = local.size(dim)
            loaded_weight = loaded_weight.narrow(dim, self.tp_rank * width, width)
        local.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the local partial product and reduce it across ranks."""
        # Only rank 0 contributes the bias, so it appears once in the result.
        bias = self.bias if self.tp_rank == 0 else None
        partial = F.linear(x, self.weight, bias)
        if self.tp_size > 1:
            # In-place reduction across ranks yields the full output.
            dist.all_reduce(partial)
        return partial