ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
146 lines
6.0 KiB
Python
146 lines
6.0 KiB
Python
import atexit
|
||
from dataclasses import fields
|
||
from time import perf_counter
|
||
from tqdm.auto import tqdm
|
||
from transformers import AutoTokenizer
|
||
import torch.multiprocessing as mp
|
||
|
||
from nanovllm.config import Config
|
||
from nanovllm.sampling_params import SamplingParams
|
||
from nanovllm.engine.sequence import Sequence
|
||
from nanovllm.engine.scheduler import Scheduler
|
||
from nanovllm.engine.model_runner import ModelRunner
|
||
|
||
|
||
class LLMEngine:
|
||
"""LLM 推理引擎:整个 nano-vllm 的入口和协调者。
|
||
|
||
职责:
|
||
1. 初始化配置、启动张量并行的子进程。
|
||
2. 管理请求(添加/调度/生成)。
|
||
3. 协调调度器和模型运行器之间的交互。
|
||
|
||
使用方式:
|
||
>>> engine = LLMEngine("/path/to/model", enforce_eager=True)
|
||
>>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64))
|
||
"""
|
||
|
||
def __init__(self, model, **kwargs):
|
||
# 过滤出 Config 支持的参数,忽略其他参数
|
||
config_fields = {field.name for field in fields(Config)}
|
||
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
||
config = Config(model, **config_kwargs)
|
||
Sequence.block_size = config.kvcache_block_size
|
||
|
||
# 启动张量并行的工作进程(rank 1, 2, ...)
|
||
# 使用 "spawn" 方式创建子进程,确保 CUDA context 独立
|
||
self.ps = []
|
||
self.events = []
|
||
ctx = mp.get_context("spawn")
|
||
for i in range(1, config.tensor_parallel_size):
|
||
event = ctx.Event() # 用于通知工作进程有新任务
|
||
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
||
process.start()
|
||
self.ps.append(process)
|
||
self.events.append(event)
|
||
|
||
# 在主进程中初始化模型运行器(rank 0)
|
||
self.model_runner = ModelRunner(config, 0, self.events)
|
||
# 加载 tokenizer(用于文本 ↔ token ID 转换)
|
||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||
config.eos = self.tokenizer.eos_token_id
|
||
# 初始化调度器(此时 num_kvcache_blocks 已由 ModelRunner 分配完毕)
|
||
self.scheduler = Scheduler(config)
|
||
# 注册退出清理函数
|
||
atexit.register(self.exit)
|
||
|
||
def exit(self):
|
||
"""通知所有进程退出并等待。"""
|
||
self.model_runner.call("exit")
|
||
del self.model_runner
|
||
for p in self.ps:
|
||
p.join()
|
||
|
||
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||
"""添加一个生成请求到调度器的等待队列。
|
||
|
||
Args:
|
||
prompt: 可以是字符串(会被 tokenizer 编码为 token ID)或已编码的 token ID 列表。
|
||
sampling_params: 采样参数。
|
||
"""
|
||
if isinstance(prompt, str):
|
||
prompt = self.tokenizer.encode(prompt)
|
||
seq = Sequence(prompt, sampling_params)
|
||
self.scheduler.add(seq)
|
||
|
||
def step(self):
|
||
"""执行一个推理步骤。
|
||
|
||
1. 调度器选择要处理的序列(prefill 或 decode)。
|
||
2. 模型运行器执行前向推理和采样。
|
||
3. 调度器后处理结果(更新缓存、检查终止条件)。
|
||
|
||
Returns:
|
||
outputs: 已完成序列的 (seq_id, completion_token_ids) 列表。
|
||
num_tokens: prefill 时为处理的 token 总数(正数),decode 时为序列数的负数。
|
||
"""
|
||
seqs, is_prefill = self.scheduler.schedule()
|
||
num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
|
||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||
self.scheduler.postprocess(seqs, token_ids, is_prefill)
|
||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||
return outputs, num_tokens
|
||
|
||
def is_finished(self):
|
||
"""检查是否所有请求都已完成。"""
|
||
return self.scheduler.is_finished()
|
||
|
||
def generate(
|
||
self,
|
||
prompts: list[str] | list[list[int]],
|
||
sampling_params: SamplingParams | list[SamplingParams],
|
||
use_tqdm: bool = True,
|
||
) -> list[str]:
|
||
"""批量生成文本的入口方法。
|
||
|
||
工作流程:
|
||
1. 将所有 prompt 添加为请求。
|
||
2. 循环执行 step() 直到所有序列完成。
|
||
3. 收集结果并解码为文本。
|
||
|
||
Args:
|
||
prompts: prompt 列表,每个元素可以是字符串或 token ID 列表。
|
||
sampling_params: 单个采样参数(应用于所有 prompt)或每个 prompt 对应的采样参数列表。
|
||
use_tqdm: 是否显示进度条。
|
||
|
||
Returns:
|
||
列表,每个元素是 {"text": 解码文本, "token_ids": token ID 列表}。
|
||
"""
|
||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
|
||
if not isinstance(sampling_params, list):
|
||
sampling_params = [sampling_params] * len(prompts)
|
||
for prompt, sp in zip(prompts, sampling_params):
|
||
self.add_request(prompt, sp)
|
||
outputs = {}
|
||
prefill_throughput = decode_throughput = 0.
|
||
while not self.is_finished():
|
||
t = perf_counter()
|
||
output, num_tokens = self.step()
|
||
# num_tokens > 0 表示 prefill,< 0 表示 decode(取负得到序列数)
|
||
if num_tokens > 0:
|
||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||
else:
|
||
decode_throughput = -num_tokens / (perf_counter() - t)
|
||
pbar.set_postfix({
|
||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||
"Decode": f"{int(decode_throughput)}tok/s",
|
||
})
|
||
for seq_id, token_ids in output:
|
||
outputs[seq_id] = token_ids
|
||
pbar.update(1)
|
||
pbar.close()
|
||
# 按 seq_id 排序输出(保证与输入顺序一致)
|
||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||
return outputs
|