Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

359 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pickle
import torch
import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
class ModelRunner:
    """Per-rank model runner: forward passes, KV cache setup, CUDA graph capture and tensor-parallel dispatch.

    Under tensor parallelism (``tensor_parallel_size > 1``):
    - Rank 0 is the driver: it publishes method calls through shared memory
      (``write_shm``) and is the only rank that samples tokens (``run``).
    - Ranks > 0 are workers: they block in ``loop`` and execute whatever
      call rank 0 published through ``SharedMemory``.
    - Every rank allocates its own KV cache covering
      ``num_key_value_heads // world_size`` heads.

    Lifecycle:
    1. Init: load the model -> warm up -> allocate the KV cache -> (optional) capture CUDA graphs.
    2. Inference: receive sequences -> prepare inputs -> run the model -> sample tokens.
    3. Exit: release resources.
    """
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
    """Set up the process group, the model, the KV cache and (optionally) CUDA graphs for one rank.

    Args:
        config: Engine configuration; this method reads ``hf_config``,
            ``kvcache_block_size``, ``enforce_eager``,
            ``tensor_parallel_size`` and ``model`` from it.
        rank: Rank of this process; also used as the CUDA device index.
        event: Rank 0 iterates over this value in ``write_shm`` (a list of
            events); ranks > 0 wait on and clear it in ``read_shm`` (a
            single event).

    Note:
        With ``tensor_parallel_size > 1`` and ``rank > 0`` the constructor
        enters ``loop()`` and only returns once an ``"exit"`` call has been
        processed.
    """
    self.config = config
    hf_config = config.hf_config
    self.block_size = config.kvcache_block_size
    self.enforce_eager = config.enforce_eager
    self.world_size = config.tensor_parallel_size
    self.rank = rank
    self.event = event
    # Join the NCCL process group; the TCP URL is only the rendezvous address.
    dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
    torch.cuda.set_device(rank)
    # Build and load the model with the HF dtype on the GPU by temporarily
    # switching torch's defaults; the previous dtype is restored further down.
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(hf_config.dtype)
    torch.set_default_device("cuda")
    self.model = Qwen3ForCausalLM(hf_config)
    load_model(self.model, config.model)
    self.sampler = Sampler()
    # Warm up with one forward pass before the KV cache is sized.
    self.warmup_model()
    # Allocate the KV cache from the remaining GPU memory budget.
    self.allocate_kv_cache()
    # Capture CUDA graphs for decode unless eager execution is enforced.
    if not self.enforce_eager:
        self.capture_cudagraph()
    # The steps above rely on the "cuda" default device (their tensors are
    # created without an explicit device), so the defaults are reset only now.
    torch.set_default_device("cpu")
    torch.set_default_dtype(default_dtype)
    # Tensor parallelism: set up the shared-memory channel between ranks.
    if self.world_size > 1:
        if rank == 0:
            # Rank 0 creates the segment; the barrier releases the workers
            # only after it exists.
            self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
            dist.barrier()
        else:
            dist.barrier()
            self.shm = SharedMemory(name="nanovllm")
            self.loop()  # workers stay here until an "exit" call arrives
def exit(self):
    """Release shared memory, captured CUDA graphs and the process group."""
    tensor_parallel = self.world_size > 1
    if tensor_parallel:
        # Every rank detaches first; the barrier orders those detaches
        # before rank 0 (the creator of the segment) removes it.
        self.shm.close()
        dist.barrier()
        if self.rank == 0:
            self.shm.unlink()
    if not self.enforce_eager:
        # These attributes are only created by capture_cudagraph().
        del self.graphs, self.graph_pool
    torch.cuda.synchronize()
    dist.destroy_process_group()
def loop(self):
    """Worker loop: execute calls published by rank 0 until ``"exit"`` has been handled."""
    finished = False
    while not finished:
        method_name, args = self.read_shm()
        # "exit" itself is dispatched here before the loop terminates.
        self.call(method_name, *args)
        finished = method_name == "exit"
def read_shm(self):
    """Wait for rank 0 to publish a call and decode it from shared memory.

    Layout of the segment: bytes 0-3 hold the payload length as a
    little-endian integer, followed by the pickled ``[method_name, *args]``.

    Returns:
        ``(method_name, args)`` where ``args`` is a list of the remaining
        unpickled items.
    """
    assert self.world_size > 1 and self.rank > 0
    self.event.wait()  # blocks until the event is set
    buf = self.shm.buf
    payload_len = int.from_bytes(buf[0:4], "little")
    message = pickle.loads(buf[4:payload_len + 4])
    method_name, *args = message
    # Clear the event so the next wait() blocks until a new message arrives.
    self.event.clear()
    return method_name, args
def write_shm(self, method_name, *args):
    """Publish a method call to the worker ranks through shared memory.

    The message is ``pickle.dumps([method_name, *args])`` preceded by its
    length as a 4-byte little-endian integer. After the message is written,
    every event in ``self.event`` is set.

    Args:
        method_name: Name of the method the workers should invoke.
        *args: Positional arguments for that method; must be picklable.

    Raises:
        ValueError: If the framed message does not fit into the shared
            memory segment. In that case nothing is written and no event
            is set, so a previously published message stays intact.
    """
    assert self.world_size > 1 and self.rank == 0
    data = pickle.dumps([method_name, *args])
    n = len(data)
    capacity = len(self.shm.buf)
    # Validate before touching the buffer: otherwise the length header would
    # already be overwritten when the payload assignment fails.
    if n + 4 > capacity:
        raise ValueError(
            f"shared-memory message for {method_name!r} needs {n + 4} bytes "
            f"but the segment holds only {capacity}"
        )
    self.shm.buf[0:4] = n.to_bytes(4, "little")
    self.shm.buf[4:n + 4] = data
    for event in self.event:
        event.set()
def call(self, method_name, *args):
    """Invoke ``method_name`` on this runner and return its result.

    On rank 0 of a tensor-parallel group the call is first published to the
    workers via ``write_shm`` and then executed locally.
    """
    is_tp_driver = self.world_size > 1 and self.rank == 0
    if is_tp_driver:
        self.write_shm(method_name, *args)
    target = getattr(self, method_name, None)
    return target(*args)
def warmup_model(self):
    """Run one dummy prefill pass sized from the configured limits.

    The CUDA cache is emptied and the peak-memory statistics are reset
    before the pass, and the cache is emptied again afterwards.
    ``allocate_kv_cache`` later reads the ``allocated_bytes.all.peak``
    statistic, which therefore covers this pass.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cfg = self.config
    token_budget = cfg.max_num_batched_tokens
    seq_len = min(token_budget, cfg.max_model_len)
    # As many full-length sequences as fit in the token budget, capped by
    # the configured sequence limit.
    num_seqs = min(token_budget // seq_len, cfg.max_num_seqs)
    dummy_tokens = [0] * seq_len
    seqs = [Sequence(list(dummy_tokens)) for _ in range(num_seqs)]
    for seq in seqs:
        seq.num_scheduled_tokens = seq_len
    self.run(seqs, True)
    torch.cuda.empty_cache()
def allocate_kv_cache(self):
    """Size and allocate the KV cache from the GPU memory budget and attach it to the model.

    Budget in bytes::

        total * gpu_memory_utilization - used - peak + current

    ``total`` and ``used`` (``total - free``) come from
    ``torch.cuda.mem_get_info()``; ``peak`` and ``current`` are the
    ``allocated_bytes.all.*`` allocator statistics. The ``peak - current``
    term presumably reserves headroom for the transient memory seen during
    warm-up -- confirm against ``warmup_model``.

    The resulting block count is stored in ``config.num_kvcache_blocks``.

    KV cache shape:
    ``(2, num_layers, num_blocks, block_size, num_kv_heads, head_dim)``;
    index 0 of the first dimension is bound to ``k_cache`` and index 1 to
    ``v_cache`` of each module that has both attributes.
    """
    config = self.config
    hf_config = config.hf_config
    free, total = torch.cuda.mem_get_info()
    used = total - free
    peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
    current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    # KV heads are divided across the tensor-parallel ranks.
    num_kv_heads = hf_config.num_key_value_heads // self.world_size
    head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
    # Bytes per block: 2 (K and V) x layers x block_size x KV heads x head_dim x dtype size.
    block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
    config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
    assert config.num_kvcache_blocks > 0
    # No dtype/device is passed: the tensor uses torch's current defaults,
    # which __init__ sets to hf_config.dtype and "cuda" before calling this.
    self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
    # Hand each module that exposes k_cache/v_cache its per-layer views, in
    # the order model.modules() yields them.
    layer_id = 0
    for module in self.model.modules():
        if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
            module.k_cache = self.kv_cache[0, layer_id]
            module.v_cache = self.kv_cache[1, layer_id]
            layer_id += 1
def prepare_block_tables(self, seqs: list[Sequence]):
    """Pad every sequence's ``block_table`` to a common width and return it as an int32 CUDA tensor.

    Shorter tables are right-padded with ``-1``.
    """
    width = max(len(seq.block_table) for seq in seqs)
    rows = []
    for seq in seqs:
        table = seq.block_table
        rows.append(table + [-1] * (width - len(table)))
    # Built in pinned host memory, then copied with non_blocking=True.
    host = torch.tensor(rows, dtype=torch.int32, pin_memory=True)
    return host.cuda(non_blocking=True)
def prepare_prefill(self, seqs: list[Sequence]):
    """Build the prefill inputs for ``seqs`` and publish the attention context.

    For each sequence the tokens ``[num_cached_tokens,
    num_cached_tokens + num_scheduled_tokens)`` are scheduled. All scheduled
    tokens are concatenated into one flat input.

    Values passed to ``set_context``:
    - ``cu_seqlens_q`` / ``cu_seqlens_k``: cumulative query and key lengths;
      the key length of a sequence includes its cached prefix.
    - ``max_seqlen_q`` / ``max_seqlen_k``: the per-sequence maxima.
    - ``slot_mapping``: for every scheduled token, ``block_id * block_size
      + offset`` as derived from the sequence's ``block_table``. Sequences
      with an empty ``block_table`` (the warm-up case) contribute nothing.
    - ``block_tables``: only built when the total key length exceeds the
      total query length, i.e. when some tokens are already cached.

    Returns:
        ``(input_ids, positions)`` as int64 CUDA tensors.
    """
    block_size = self.block_size
    input_ids, positions, slot_mapping = [], [], []
    cu_seqlens_q, cu_seqlens_k = [0], [0]
    max_seqlen_q = max_seqlen_k = 0
    block_tables = None

    for seq in seqs:
        start = seq.num_cached_tokens
        seqlen_q = seq.num_scheduled_tokens
        end = start + seqlen_q
        seqlen_k = end  # keys span positions 0..end, cached prefix included

        input_ids.extend(seq[start:end])
        positions.extend(range(start, end))
        cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
        cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
        max_seqlen_q = max(seqlen_q, max_seqlen_q)
        max_seqlen_k = max(seqlen_k, max_seqlen_k)

        if seq.block_table:
            first_block = start // block_size
            last_block = (end + block_size - 1) // block_size  # exclusive
            for blk in range(first_block, last_block):
                base = seq.block_table[blk] * block_size
                # Only the first block may start mid-block and only the
                # last one may end before the block is full.
                lo = start % block_size if blk == first_block else 0
                hi = end - blk * block_size if blk == last_block - 1 else block_size
                slot_mapping.extend(range(base + lo, base + hi))

    if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
        block_tables = self.prepare_block_tables(seqs)

    def to_gpu(values, dtype):
        return torch.tensor(values, dtype=dtype, pin_memory=True).cuda(non_blocking=True)

    input_ids = to_gpu(input_ids, torch.int64)
    positions = to_gpu(positions, torch.int64)
    cu_seqlens_q = to_gpu(cu_seqlens_q, torch.int32)
    cu_seqlens_k = to_gpu(cu_seqlens_k, torch.int32)
    slot_mapping = to_gpu(slot_mapping, torch.int32)
    set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
    return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
    """Build the decode inputs (one token per sequence) and publish the attention context.

    Per sequence:
    - input id: ``seq.last_token``
    - position: ``len(seq) - 1``
    - context length: ``len(seq)``
    - slot: ``block_table[-1] * block_size + last_block_num_tokens - 1``,
      i.e. the slot of the last token inside the sequence's last block.

    ``slot_mapping``, ``context_lens`` and the padded ``block_tables`` are
    handed to ``set_context``.

    Returns:
        ``(input_ids, positions)`` as int64 CUDA tensors.
    """
    block_size = self.block_size
    input_ids = [seq.last_token for seq in seqs]
    context_lens = [len(seq) for seq in seqs]
    positions = [length - 1 for length in context_lens]
    slot_mapping = [
        seq.block_table[-1] * block_size + seq.last_block_num_tokens - 1
        for seq in seqs
    ]

    def to_gpu(values, dtype):
        return torch.tensor(values, dtype=dtype, pin_memory=True).cuda(non_blocking=True)

    input_ids = to_gpu(input_ids, torch.int64)
    positions = to_gpu(positions, torch.int64)
    slot_mapping = to_gpu(slot_mapping, torch.int32)
    context_lens = to_gpu(context_lens, torch.int32)
    block_tables = self.prepare_block_tables(seqs)
    set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
    return input_ids, positions
def prepare_sample(self, seqs: list[Sequence]):
    """Return the per-sequence sampling temperatures as a float32 CUDA tensor."""
    host = torch.tensor(
        [seq.temperature for seq in seqs],
        dtype=torch.float32,
        pin_memory=True,
    )
    return host.cuda(non_blocking=True)
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
    """Run the forward pass and return the logits.

    The model is called directly for prefill, when ``enforce_eager`` is
    set, for batches larger than 512, and when no captured CUDA graph is
    large enough for the batch. Otherwise the smallest captured graph with
    batch size ``>= bs`` is replayed on the pre-allocated buffers in
    ``self.graph_vars``.

    Args:
        input_ids: Token ids; the size of dimension 0 is the batch size.
        positions: Position ids matching ``input_ids``.
        is_prefill: Whether this is a prefill step.

    Returns:
        ``self.model.compute_logits(...)`` of the hidden states.
    """
    bs = input_ids.size(0)
    graph_bs = None
    if not (is_prefill or self.enforce_eager or bs > 512):
        # The captured sizes need not cover every batch size up to the
        # limit; a missing match falls back to eager execution below
        # instead of raising StopIteration.
        graph_bs = next((x for x in self.graph_bs if x >= bs), None)
    if graph_bs is None:
        return self.model.compute_logits(self.model(input_ids, positions))

    graph = self.graphs[graph_bs]
    context = get_context()
    graph_vars = self.graph_vars
    # Copy the live inputs into the fixed buffers the graph was captured on.
    graph_vars["input_ids"][:bs] = input_ids
    graph_vars["positions"][:bs] = positions
    # Rows beyond bs are padding: slot -1 and context length 0.
    graph_vars["slot_mapping"].fill_(-1)
    graph_vars["slot_mapping"][:bs] = context.slot_mapping
    graph_vars["context_lens"].zero_()
    graph_vars["context_lens"][:bs] = context.context_lens
    # NOTE: block_tables is not cleared; entries outside the written
    # region keep whatever an earlier step left there.
    graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
    graph.replay()
    return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
    """Execute one inference step: prepare inputs, run the model, sample.

    Returns:
        The sampled token ids on rank 0; ``None`` on every other rank.
    """
    prepare = self.prepare_prefill if is_prefill else self.prepare_decode
    input_ids, positions = prepare(seqs)
    # Only rank 0 samples, so only rank 0 needs the temperatures.
    is_sampling_rank = self.rank == 0
    temperatures = self.prepare_sample(seqs) if is_sampling_rank else None
    logits = self.run_model(input_ids, positions, is_prefill)
    if is_sampling_rank:
        token_ids = self.sampler(logits, temperatures).tolist()
    else:
        token_ids = None
    reset_context()
    return token_ids
@torch.inference_mode()
def capture_cudagraph(self):
    """Capture one CUDA graph per decode batch size on shared, fixed buffers.

    Captured batch sizes are 1, 2, 4, 8, every multiple of 16 up to
    ``max_bs`` and ``max_bs`` itself, where
    ``max_bs = min(max_num_seqs, 512)``. Sizes above ``max_bs`` are skipped.
    Including ``max_bs`` guarantees that ``run_model`` finds a graph for
    every batch size up to ``max_bs`` even when it is not a multiple of 16.

    Populates ``self.graph_bs`` (ascending), ``self.graphs`` (batch size ->
    graph), ``self.graph_pool`` and ``self.graph_vars`` (the buffers that
    ``run_model`` fills before replaying a graph).
    """
    config = self.config
    hf_config = config.hf_config
    max_bs = min(config.max_num_seqs, 512)
    max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
    # Fixed input/output buffers shared by all graphs. No device is passed:
    # __init__ calls this while torch's default device is "cuda".
    input_ids = torch.zeros(max_bs, dtype=torch.int64)
    positions = torch.zeros(max_bs, dtype=torch.int64)
    slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
    context_lens = torch.zeros(max_bs, dtype=torch.int32)
    block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
    outputs = torch.zeros(max_bs, hf_config.hidden_size)
    # Ascending order matters: run_model picks the first size >= bs.
    candidates = (1, 2, 4, 8, *range(16, max_bs + 1, 16), max_bs)
    self.graph_bs = sorted({bs for bs in candidates if 1 <= bs <= max_bs})
    self.graphs = {}
    self.graph_pool = None
    # Capture from the largest size down; the pool created by the first
    # capture is reused by all later ones.
    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
        # One uncaptured run first, then the captured run.
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
        if self.graph_pool is None:
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()
    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs,
    )