ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
66 lines
1.9 KiB
Python
Executable File
66 lines
1.9 KiB
Python
Executable File
import torch
|
||
from torch import nn
|
||
|
||
|
||
class RMSNorm(nn.Module):
|
||
"""Root Mean Square Layer Normalization(RMSNorm)。
|
||
|
||
与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。
|
||
公式: output = x / sqrt(mean(x^2) + eps) * weight
|
||
|
||
提供两个前向路径:
|
||
- rms_forward: 标准 RMSNorm。
|
||
- add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm),
|
||
减少一次显存读写,是 vLLM 等推理框架的常见优化。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
hidden_size: int,
|
||
eps: float = 1e-6,
|
||
) -> None:
|
||
super().__init__()
|
||
self.eps = eps
|
||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||
|
||
@torch.compile
|
||
def rms_forward(
|
||
self,
|
||
x: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
orig_dtype = x.dtype
|
||
x = x.float()
|
||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||
x.mul_(torch.rsqrt(var + self.eps))
|
||
x = x.to(orig_dtype).mul_(self.weight)
|
||
return x
|
||
|
||
@torch.compile
|
||
def add_rms_forward(
|
||
self,
|
||
x: torch.Tensor,
|
||
residual: torch.Tensor,
|
||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
"""融合残差加法和 RMSNorm。
|
||
|
||
Returns:
|
||
(normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。
|
||
"""
|
||
orig_dtype = x.dtype
|
||
x = x.float().add_(residual.float())
|
||
residual = x.to(orig_dtype)
|
||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||
x.mul_(torch.rsqrt(var + self.eps))
|
||
x = x.to(orig_dtype).mul_(self.weight)
|
||
return x, residual
|
||
|
||
def forward(
|
||
self,
|
||
x: torch.Tensor,
|
||
residual: torch.Tensor | None = None,
|
||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||
if residual is None:
|
||
return self.rms_forward(x)
|
||
else:
|
||
return self.add_rms_forward(x, residual)
|