add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -13,40 +13,77 @@ from nanovllm.engine.model_runner import ModelRunner
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""LLM 推理引擎:整个 nano-vllm 的入口和协调者。
|
||||
|
||||
职责:
|
||||
1. 初始化配置、启动张量并行的子进程。
|
||||
2. 管理请求(添加/调度/生成)。
|
||||
3. 协调调度器和模型运行器之间的交互。
|
||||
|
||||
使用方式:
|
||||
>>> engine = LLMEngine("/path/to/model", enforce_eager=True)
|
||||
>>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64))
|
||||
"""
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
# 过滤出 Config 支持的参数,忽略其他参数
|
||||
config_fields = {field.name for field in fields(Config)}
|
||||
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
||||
config = Config(model, **config_kwargs)
|
||||
Sequence.block_size = config.kvcache_block_size
|
||||
|
||||
# 启动张量并行的工作进程(rank 1, 2, ...)
|
||||
# 使用 "spawn" 方式创建子进程,确保 CUDA context 独立
|
||||
self.ps = []
|
||||
self.events = []
|
||||
ctx = mp.get_context("spawn")
|
||||
for i in range(1, config.tensor_parallel_size):
|
||||
event = ctx.Event()
|
||||
event = ctx.Event() # 用于通知工作进程有新任务
|
||||
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
||||
process.start()
|
||||
self.ps.append(process)
|
||||
self.events.append(event)
|
||||
|
||||
# 在主进程中初始化模型运行器(rank 0)
|
||||
self.model_runner = ModelRunner(config, 0, self.events)
|
||||
# 加载 tokenizer(用于文本 ↔ token ID 转换)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
config.eos = self.tokenizer.eos_token_id
|
||||
# 初始化调度器(此时 num_kvcache_blocks 已由 ModelRunner 分配完毕)
|
||||
self.scheduler = Scheduler(config)
|
||||
# 注册退出清理函数
|
||||
atexit.register(self.exit)
|
||||
|
||||
def exit(self):
|
||||
"""通知所有进程退出并等待。"""
|
||||
self.model_runner.call("exit")
|
||||
del self.model_runner
|
||||
for p in self.ps:
|
||||
p.join()
|
||||
|
||||
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||||
"""添加一个生成请求到调度器的等待队列。
|
||||
|
||||
Args:
|
||||
prompt: 可以是字符串(会被 tokenizer 编码为 token ID)或已编码的 token ID 列表。
|
||||
sampling_params: 采样参数。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
seq = Sequence(prompt, sampling_params)
|
||||
self.scheduler.add(seq)
|
||||
|
||||
def step(self):
|
||||
"""执行一个推理步骤。
|
||||
|
||||
1. 调度器选择要处理的序列(prefill 或 decode)。
|
||||
2. 模型运行器执行前向推理和采样。
|
||||
3. 调度器后处理结果(更新缓存、检查终止条件)。
|
||||
|
||||
Returns:
|
||||
outputs: 已完成序列的 (seq_id, completion_token_ids) 列表。
|
||||
num_tokens: prefill 时为处理的 token 总数(正数),decode 时为序列数的负数。
|
||||
"""
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
@@ -55,6 +92,7 @@ class LLMEngine:
|
||||
return outputs, num_tokens
|
||||
|
||||
def is_finished(self):
|
||||
"""检查是否所有请求都已完成。"""
|
||||
return self.scheduler.is_finished()
|
||||
|
||||
def generate(
|
||||
@@ -63,6 +101,21 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
"""批量生成文本的入口方法。
|
||||
|
||||
工作流程:
|
||||
1. 将所有 prompt 添加为请求。
|
||||
2. 循环执行 step() 直到所有序列完成。
|
||||
3. 收集结果并解码为文本。
|
||||
|
||||
Args:
|
||||
prompts: prompt 列表,每个元素可以是字符串或 token ID 列表。
|
||||
sampling_params: 单个采样参数(应用于所有 prompt)或每个 prompt 对应的采样参数列表。
|
||||
use_tqdm: 是否显示进度条。
|
||||
|
||||
Returns:
|
||||
列表,每个元素是 {"text": 解码文本, "token_ids": token ID 列表}。
|
||||
"""
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
|
||||
if not isinstance(sampling_params, list):
|
||||
sampling_params = [sampling_params] * len(prompts)
|
||||
@@ -73,6 +126,7 @@ class LLMEngine:
|
||||
while not self.is_finished():
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
# num_tokens > 0 表示 prefill,< 0 表示 decode(取负得到序列数)
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
@@ -85,6 +139,7 @@ class LLMEngine:
|
||||
outputs[seq_id] = token_ids
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
# 按 seq_id 排序输出(保证与输入顺序一致)
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user