Files
nano-vllm/nanovllm/layers/sampler.py
T
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

24 lines
950 B
Python

import torch
from torch import nn
class Sampler(nn.Module):
"""采样器:将 logits 转换为 token ID。
使用 Gumbel-like 采样方法(而非标准的 top-k/top-p):
1. 将 logits 除以温度(temperature)。
2. 计算 softmax 得到概率分布。
3. 用指数分布噪声扰动概率,取 argmax。
这种方法等价于从 softmax(logits/temperature) 分布中采样,
但避免了逐元素随机选择的低效操作,全部用张量运算实现。
"""
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
# 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
return sample_tokens