import atexit from dataclasses import fields from time import perf_counter from tqdm.auto import tqdm from transformers import AutoTokenizer import torch.multiprocessing as mp from nanovllm.config import Config from nanovllm.sampling_params import SamplingParams from nanovllm.engine.sequence import Sequence from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.model_runner import ModelRunner class LLMEngine: """LLM 推理引擎:整个 nano-vllm 的入口和协调者。 职责: 1. 初始化配置、启动张量并行的子进程。 2. 管理请求(添加/调度/生成)。 3. 协调调度器和模型运行器之间的交互。 使用方式: >>> engine = LLMEngine("/path/to/model", enforce_eager=True) >>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64)) """ def __init__(self, model, **kwargs): # 过滤出 Config 支持的参数,忽略其他参数 config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) Sequence.block_size = config.kvcache_block_size # 启动张量并行的工作进程(rank 1, 2, ...) # 使用 "spawn" 方式创建子进程,确保 CUDA context 独立 self.ps = [] self.events = [] ctx = mp.get_context("spawn") for i in range(1, config.tensor_parallel_size): event = ctx.Event() # 用于通知工作进程有新任务 process = ctx.Process(target=ModelRunner, args=(config, i, event)) process.start() self.ps.append(process) self.events.append(event) # 在主进程中初始化模型运行器(rank 0) self.model_runner = ModelRunner(config, 0, self.events) # 加载 tokenizer(用于文本 ↔ token ID 转换) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) config.eos = self.tokenizer.eos_token_id # 初始化调度器(此时 num_kvcache_blocks 已由 ModelRunner 分配完毕) self.scheduler = Scheduler(config) # 注册退出清理函数 atexit.register(self.exit) def exit(self): """通知所有进程退出并等待。""" self.model_runner.call("exit") del self.model_runner for p in self.ps: p.join() def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): """添加一个生成请求到调度器的等待队列。 Args: prompt: 可以是字符串(会被 tokenizer 编码为 token ID)或已编码的 token ID 列表。 sampling_params: 采样参数。 """ if isinstance(prompt, str): prompt = self.tokenizer.encode(prompt) seq = Sequence(prompt, sampling_params) self.scheduler.add(seq) def step(self): """执行一个推理步骤。 1. 调度器选择要处理的序列(prefill 或 decode)。 2. 模型运行器执行前向推理和采样。 3. 调度器后处理结果(更新缓存、检查终止条件)。 Returns: outputs: 已完成序列的 (seq_id, completion_token_ids) 列表。 num_tokens: prefill 时为处理的 token 总数(正数),decode 时为序列数的负数。 """ seqs, is_prefill = self.scheduler.schedule() num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs) token_ids = self.model_runner.call("run", seqs, is_prefill) self.scheduler.postprocess(seqs, token_ids, is_prefill) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] return outputs, num_tokens def is_finished(self): """检查是否所有请求都已完成。""" return self.scheduler.is_finished() def generate( self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True, ) -> list[str]: """批量生成文本的入口方法。 工作流程: 1. 将所有 prompt 添加为请求。 2. 循环执行 step() 直到所有序列完成。 3. 收集结果并解码为文本。 Args: prompts: prompt 列表,每个元素可以是字符串或 token ID 列表。 sampling_params: 单个采样参数(应用于所有 prompt)或每个 prompt 对应的采样参数列表。 use_tqdm: 是否显示进度条。 Returns: 列表,每个元素是 {"text": 解码文本, "token_ids": token ID 列表}。 """ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm) if not isinstance(sampling_params, list): sampling_params = [sampling_params] * len(prompts) for prompt, sp in zip(prompts, sampling_params): self.add_request(prompt, sp) outputs = {} prefill_throughput = decode_throughput = 0. while not self.is_finished(): t = perf_counter() output, num_tokens = self.step() # num_tokens > 0 表示 prefill,< 0 表示 decode(取负得到序列数) if num_tokens > 0: prefill_throughput = num_tokens / (perf_counter() - t) else: decode_throughput = -num_tokens / (perf_counter() - t) pbar.set_postfix({ "Prefill": f"{int(prefill_throughput)}tok/s", "Decode": f"{int(decode_throughput)}tok/s", }) for seq_id, token_ids in output: outputs[seq_id] = token_ids pbar.update(1) pbar.close() # 按 seq_id 排序输出(保证与输入顺序一致) outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())] outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] return outputs