import torch from torch import nn class Sampler(nn.Module): """采样器:将 logits 转换为 token ID。 使用 Gumbel-like 采样方法(而非标准的 top-k/top-p): 1. 将 logits 除以温度(temperature)。 2. 计算 softmax 得到概率分布。 3. 用指数分布噪声扰动概率,取 argmax。 这种方法等价于从 softmax(logits/temperature) 分布中采样, 但避免了逐元素随机选择的低效操作,全部用张量运算实现。 """ @torch.compile def forward(self, logits: torch.Tensor, temperatures: torch.Tensor): logits = logits.float().div_(temperatures.unsqueeze(dim=1)) probs = torch.softmax(logits, dim=-1) # 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样 sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1) return sample_tokens