Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

146 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import atexit
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp
from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
class LLMEngine:
    """LLM inference engine: the entry point and coordinator of nano-vllm.

    Responsibilities:
    1. Build the configuration and launch the tensor-parallel worker processes.
    2. Manage requests (add / schedule / generate).
    3. Coordinate the interaction between the scheduler and the model runner.

    Usage:
        >>> engine = LLMEngine("/path/to/model", enforce_eager=True)
        >>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64))
    """

    def __init__(self, model, **kwargs):
        """Create the engine.

        Args:
            model: forwarded as the first positional argument to ``Config``
                (a model path in the usage example above).
            **kwargs: entries whose names match a ``Config`` dataclass field
                are forwarded to ``Config``; all other entries are silently
                ignored.
        """
        # Keep only the kwargs that Config actually declares as fields.
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        Sequence.block_size = config.kvcache_block_size

        # Start one worker process per additional tensor-parallel rank (1, 2, ...).
        # The "spawn" start method launches a fresh interpreter for each child,
        # so no CUDA context is inherited from this process.
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()  # used to notify worker `i` that there is new work
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)

        # Rank 0 runs in this (main) process and is handed every worker's event.
        self.model_runner = ModelRunner(config, 0, self.events)
        # Tokenizer used for text <-> token-ID conversion.
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        # The scheduler is created last. ModelRunner is expected to have set
        # num_kvcache_blocks on `config` by now (not visible here; verify in
        # ModelRunner).
        self.scheduler = Scheduler(config)
        # Shut the worker processes down when the interpreter exits.
        atexit.register(self.exit)

    def exit(self):
        """Tell all processes to exit and wait for the worker processes to finish.

        Safe to call more than once. This method is registered with ``atexit``
        in ``__init__``. Without the guard below, an explicit ``engine.exit()``
        would be followed by a second call at interpreter shutdown. That call
        would fail on the already-deleted ``model_runner`` attribute.
        """
        if not hasattr(self, "model_runner"):
            return  # already shut down
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
        """Add one generation request to the scheduler's queue.

        Args:
            prompt: either a string, which is encoded to token IDs with the
                tokenizer, or an already-encoded list of token IDs.
            sampling_params: sampling parameters for this request.
        """
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params)
        self.scheduler.add(seq)

    def step(self):
        """Run one inference step.

        1. The scheduler selects the sequences to process (a prefill or a decode batch).
        2. The model runner performs the forward pass and sampling.
        3. The scheduler post-processes the results (updating caches and
           checking stop conditions).

        Returns:
            outputs: list of ``(seq_id, completion_token_ids)`` for the
                sequences that finished in this step.
            num_tokens: for a prefill step, the total number of scheduled
                tokens (positive). For a decode step, the negated number of
                sequences in the batch.
        """
        seqs, is_prefill = self.scheduler.schedule()
        num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids, is_prefill)
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
        return outputs, num_tokens

    def is_finished(self):
        """Return whether every request has been completed (delegates to the scheduler)."""
        return self.scheduler.is_finished()

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[dict]:
        """Generate completions for a batch of prompts.

        Workflow:
        1. Add every prompt as a request.
        2. Call ``step()`` until all sequences have finished.
        3. Collect the results and decode them to text.

        Args:
            prompts: list of prompts; each element is a string or a list of token IDs.
            sampling_params: a single ``SamplingParams`` applied to every prompt,
                or a list with exactly one entry per prompt.
            use_tqdm: whether to show a progress bar.

        Returns:
            A list with one dict per finished sequence, ordered by seq_id:
            ``{"text": decoded text, "token_ids": list of token IDs}``.

        Raises:
            ValueError: if ``sampling_params`` is a list whose length differs
                from ``len(prompts)``. Previously ``zip`` silently dropped the
                unmatched prompts.
        """
        if isinstance(sampling_params, list):
            if len(sampling_params) != len(prompts):
                raise ValueError(
                    f"expected {len(prompts)} sampling params (one per prompt), "
                    f"got {len(sampling_params)}"
                )
        else:
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)

        finished = {}
        prefill_throughput = decode_throughput = 0.0
        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
        try:  # make sure the progress bar is closed even if a step raises
            while not self.is_finished():
                start = perf_counter()
                output, num_tokens = self.step()
                elapsed = perf_counter() - start
                # num_tokens > 0: prefill step (token count).
                # Otherwise: decode step; -num_tokens is the number of sequences.
                # NOTE(review): the decode rate is shown as tok/s, which assumes
                # one new token per sequence per decode step.
                if num_tokens > 0:
                    prefill_throughput = num_tokens / elapsed
                else:
                    decode_throughput = -num_tokens / elapsed
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "Decode": f"{int(decode_throughput)}tok/s",
                })
                for seq_id, token_ids in output:
                    finished[seq_id] = token_ids
                    pbar.update(1)
        finally:
            pbar.close()

        # Sort by seq_id so the results follow submission order. This assumes
        # Sequence assigns increasing seq_ids; verify in Sequence.
        ordered = [finished[seq_id] for seq_id in sorted(finished)]
        return [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in ordered]