Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

66 lines
1.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import nn
class RMSNorm(nn.Module):
"""Root Mean Square Layer NormalizationRMSNorm)。
与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。
公式: output = x / sqrt(mean(x^2) + eps) * weight
提供两个前向路径:
- rms_forward: 标准 RMSNorm。
- add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm),
减少一次显存读写,是 vLLM 等推理框架的常见优化。
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
@torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
@torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""融合残差加法和 RMSNorm。
Returns:
(normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。
"""
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)