add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -3,10 +3,21 @@ from torch import nn
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
"""采样器:将 logits 转换为 token ID。
|
||||
|
||||
使用 Gumbel-like 采样方法(而非标准的 top-k/top-p):
|
||||
1. 将 logits 除以温度(temperature)。
|
||||
2. 计算 softmax 得到概率分布。
|
||||
3. 用指数分布噪声扰动概率,取 argmax。
|
||||
|
||||
这种方法等价于从 softmax(logits/temperature) 分布中采样,
|
||||
但避免了逐元素随机选择的低效操作,全部用张量运算实现。
|
||||
"""
|
||||
|
||||
@torch.compile
|
||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
||||
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
# 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样
|
||||
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
||||
return sample_tokens
|
||||
|
||||
Reference in New Issue
Block a user