add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@@ -3,6 +3,16 @@ from torch import nn
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Unlike LayerNorm, RMSNorm does not subtract the mean; it only computes the root mean square, so it is cheaper to compute.
    Formula: output = x / sqrt(mean(x^2) + eps) * weight

    Two forward paths are provided:
    - rms_forward: standard RMSNorm.
    - add_rms_forward: fuses the residual addition into RMSNorm (x + residual → RMSNorm),
      saving one round of GPU memory reads and writes; a common optimization in inference frameworks such as vLLM.
    """

    def __init__(
        self,
@@ -31,6 +41,11 @@ class RMSNorm(nn.Module):
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse the residual addition with RMSNorm.

        Returns:
            (normalized_output, updated_residual): the normalized output and the updated residual (= x + residual).
        """
        orig_dtype = x.dtype
        x = x.float().add_(residual.float())
        residual = x.to(orig_dtype)
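The two forward paths described in the RMSNorm docstring can be illustrated with a minimal sketch in plain Python. This is not the code from the commit: it works on Python lists instead of tensors, skips the dtype handling, and assumes a default `eps` of 1e-6 and made-up sample values. Only the function names `rms_forward` and `add_rms_forward` and the formula are taken from the docstring.

```python
import math


def rms_forward(x, weight, eps=1e-6):
    """Standard RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * weight.

    eps=1e-6 is an assumed default; the commit does not show the real one.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def add_rms_forward(x, residual, weight, eps=1e-6):
    """Add the residual first, then normalize the sum.

    Returns (normalized_output, updated_residual), where
    updated_residual = x + residual.
    """
    updated_residual = [a + b for a, b in zip(x, residual)]
    return rms_forward(updated_residual, weight, eps), updated_residual


# Hypothetical sample values, chosen only for illustration.
x = [1.0, -2.0, 3.0, -4.0]
residual = [0.5, 0.5, -0.5, -0.5]
weight = [1.0, 1.0, 1.0, 1.0]

out, new_residual = add_rms_forward(x, residual, weight)
```

With a unit `weight`, the mean of the squared outputs is close to 1, and `new_residual` holds the sum that the next layer would receive as its residual input. In a real kernel the benefit of the fused path comes from doing the addition and the normalization in one pass over memory, which this list-based sketch does not model.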