import torch from torch import nn class RMSNorm(nn.Module): """Root Mean Square Layer Normalization(RMSNorm)。 与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。 公式: output = x / sqrt(mean(x^2) + eps) * weight 提供两个前向路径: - rms_forward: 标准 RMSNorm。 - add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm), 减少一次显存读写,是 vLLM 等推理框架的常见优化。 """ def __init__( self, hidden_size: int, eps: float = 1e-6, ) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(hidden_size)) @torch.compile def rms_forward( self, x: torch.Tensor, ) -> torch.Tensor: orig_dtype = x.dtype x = x.float() var = x.pow(2).mean(dim=-1, keepdim=True) x.mul_(torch.rsqrt(var + self.eps)) x = x.to(orig_dtype).mul_(self.weight) return x @torch.compile def add_rms_forward( self, x: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """融合残差加法和 RMSNorm。 Returns: (normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。 """ orig_dtype = x.dtype x = x.float().add_(residual.float()) residual = x.to(orig_dtype) var = x.pow(2).mean(dim=-1, keepdim=True) x.mul_(torch.rsqrt(var + self.eps)) x = x.to(orig_dtype).mul_(self.weight) return x, residual def forward( self, x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if residual is None: return self.rms_forward(x) else: return self.add_rms_forward(x, residual)