Files
nano-vllm/nanovllm/layers/sampler.py
T
GeekExplorer 9e8507ef41 minor simplify
2026-04-13 22:09:46 +08:00

13 lines
407 B
Python

import torch
from torch import nn
class Sampler(nn.Module):
    """Draw one token id per row from temperature-scaled logits."""

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        """Sample token indices along the last dimension of ``logits``.

        Args:
            logits: Unnormalised scores; the last dimension is the one that
                gets sampled. Presumably ``(num_seqs, vocab_size)`` — confirm
                against the caller. The tensor is never modified.
            temperatures: Per-row divisors, given an extra axis at ``dim=1``
                before being broadcast against ``logits``. Presumably
                ``(num_seqs,)`` — confirm against the caller. Zero is not
                special-cased: dividing by it yields non-finite logits.

        Returns:
            Integer tensor of sampled indices with the last dimension of
            ``logits`` reduced away.
        """
        # Out-of-place division on purpose: ``Tensor.float()`` hands back the
        # *same* tensor when the input is already float32, so an in-place
        # ``div_`` here would silently rescale the caller's logits.
        scaled = logits.float() / temperatures.unsqueeze(dim=1)
        probs = torch.softmax(scaled, dim=-1)
        # Exponential-race sampling: argmax_i(p_i / E_i) with E_i ~ Exp(1) is
        # distributed according to p. The clamp keeps the divisor away from 0.
        noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        # ``probs`` is a fresh tensor owned by this call, so in-place is safe.
        sample_tokens = probs.div_(noise).argmax(dim=-1)
        return sample_tokens