ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
24 lines
950 B
Python
24 lines
950 B
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class Sampler(nn.Module):
|
|
"""采样器:将 logits 转换为 token ID。
|
|
|
|
使用 Gumbel-like 采样方法(而非标准的 top-k/top-p):
|
|
1. 将 logits 除以温度(temperature)。
|
|
2. 计算 softmax 得到概率分布。
|
|
3. 用指数分布噪声扰动概率,取 argmax。
|
|
|
|
这种方法等价于从 softmax(logits/temperature) 分布中采样,
|
|
但避免了逐元素随机选择的低效操作,全部用张量运算实现。
|
|
"""
|
|
|
|
@torch.compile
|
|
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
|
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
|
probs = torch.softmax(logits, dim=-1)
|
|
# 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样
|
|
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
|
return sample_tokens
|