import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist def divide(numerator, denominator): """整除断言,确保张量并行时维度能被均匀切分。""" assert numerator % denominator == 0 return numerator // denominator class LinearBase(nn.Module): """所有并行线性层的基类。 Attributes: tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。 tp_rank: 当前进程在 TP 组中的 rank。 tp_size: TP 组的总大小。 weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。 """ def __init__( self, input_size: int, output_size: int, bias: bool = False, tp_dim: int | None = None, ): super().__init__() self.tp_dim = tp_dim self.tp_rank = dist.get_rank() self.tp_size = dist.get_world_size() self.weight = nn.Parameter(torch.empty(output_size, input_size)) self.weight.weight_loader = self.weight_loader if bias: self.bias = nn.Parameter(torch.empty(output_size)) self.bias.weight_loader = self.weight_loader else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError class ReplicatedLinear(LinearBase): """复制式线性层:所有 TP rank 持有完整的权重副本。用于不需要切分的层。""" def __init__( self, input_size: int, output_size: int, bias: bool = False, ): super().__init__(input_size, output_size, bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.linear(x, self.weight, self.bias) class ColumnParallelLinear(LinearBase): """列并行线性层:将输出维度按 TP rank 切分。 每个 TP rank 持有输出维度的一个分片。例如输出维度为 4096,TP=2 时每个 rank 持有 2048。 常用于 QKV 投影和 FFN 的 gate/up 投影(这些层的输出可以独立计算)。 """ def __init__( self, input_size: int, output_size: int, bias: bool = False, ): tp_size = dist.get_world_size() super().__init__(input_size, divide(output_size, tp_size), bias, 0) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): """加载权重时按 tp_rank 切取对应的列分片。""" param_data = param.data shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.linear(x, self.weight, self.bias) class MergedColumnParallelLinear(ColumnParallelLinear): """融合的列并行线性层:将多个线性层合并为一个矩阵乘法。 典型用途是将 gate_proj 和 up_proj 融合为 gate_up_proj, 减少 kernel launch 次数,提升计算效率。 权重加载时需要根据 shard_id(子层索引)定位到正确的权重分片位置。 """ def __init__( self, input_size: int, output_sizes: list[int], bias: bool = False, ): self.output_sizes = output_sizes super().__init__(input_size, sum(output_sizes), bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): """根据 shard_id 将权重加载到融合矩阵的正确位置。""" param_data = param.data # 计算该子层在融合矩阵中的偏移量 shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) class QKVParallelLinear(ColumnParallelLinear): """QKV 融合的列并行线性层。 将 Q、K、V 三个投影合并为一个矩阵乘法。 权重按 [Q | K | V] 的顺序排列,加载时根据 shard_id("q"/"k"/"v") 定位到对应的位置。 支持 GQA:Q 的 head 数和 KV 的 head 数可以不同。 """ def __init__( self, hidden_size: int, head_size: int, total_num_heads: int, total_num_kv_heads: int | None = None, bias: bool = False, ): tp_size = dist.get_world_size() total_num_kv_heads = total_num_kv_heads or total_num_heads self.head_size = head_size self.num_heads = divide(total_num_heads, tp_size) self.num_kv_heads = divide(total_num_kv_heads, tp_size) output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size super().__init__(hidden_size, output_size, bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): """根据 shard_id ("q"/"k"/"v") 将权重加载到融合矩阵的正确位置。""" param_data = param.data assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": shard_size = self.num_heads * self.head_size shard_offset = 0 elif loaded_shard_id == "k": shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size else: # "v" shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) class RowParallelLinear(LinearBase): """行并行线性层:将输入维度按 TP rank 切分。 每个 TP rank 持有输入维度的一个分片。前向计算后需要 all-reduce 将所有 rank 的结果求和,得到完整的输出。 常用于 O 投影和 FFN 的 down 投影(这些层的输出需要跨 rank 聚合)。 偏置项只在 rank 0 添加,避免 all-reduce 后重复加 bias。 """ def __init__( self, input_size: int, output_size: int, bias: bool = False, ): tp_size = dist.get_world_size() super().__init__(divide(input_size, tp_size), output_size, bias, 1) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): """加载权重时按 tp_rank 切取对应的行分片。""" param_data = param.data if param_data.ndim == 1: # bias 不切分,每个 rank 持有完整副本 param_data.copy_(loaded_weight) return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> torch.Tensor: # 只有 rank 0 加 bias,避免 all-reduce 后重复 y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None) if self.tp_size > 1: dist.all_reduce(y) # 跨 rank 求和,得到完整输出 return y