add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,16 @@ from nanovllm.engine.sequence import Sequence
|
||||
|
||||
|
||||
class Block:
|
||||
"""KV cache 的物理块。
|
||||
|
||||
每个 Block 对应 GPU 显存中一块固定大小的 KV cache 存储,可以存放 block_size 个 token 的 KV 向量。
|
||||
|
||||
Attributes:
|
||||
block_id: 物理块的唯一 ID。
|
||||
ref_count: 引用计数,有多少个序列正在共享这个块(前缀缓存复用)。
|
||||
hash: 该块内容的哈希值,用于前缀缓存查找。初始为 -1 表示未计算哈希。
|
||||
token_ids: 该块对应的 token ID 列表,用于验证缓存命中时内容是否一致。
|
||||
"""
|
||||
|
||||
def __init__(self, block_id):
|
||||
self.block_id = block_id
|
||||
@@ -14,16 +24,33 @@ class Block:
|
||||
self.token_ids = []
|
||||
|
||||
def update(self, hash: int, token_ids: list[int]):
|
||||
"""更新块的内容哈希和 token_ids。"""
|
||||
self.hash = hash
|
||||
self.token_ids = token_ids
|
||||
|
||||
def reset(self):
|
||||
"""重置块状态,准备被重新分配。"""
|
||||
self.ref_count = 1
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
|
||||
class BlockManager:
|
||||
"""管理 KV cache 物理块的分配、释放和前缀缓存。
|
||||
|
||||
核心设计:
|
||||
- 将 KV cache 分成固定大小的物理块(block_size 个 token/块)。
|
||||
- 使用哈希表实现前缀缓存(prefix caching):不同序列如果拥有相同的 prompt 前缀,
|
||||
可以共享同一组 KV cache 块,避免重复计算。
|
||||
- 块通过引用计数(ref_count)管理生命周期,所有引用释放后块变为空闲可复用。
|
||||
|
||||
Attributes:
|
||||
block_size: 每个块存储的 token 数量。
|
||||
blocks: 所有物理块的列表,索引即 block_id。
|
||||
hash_to_block_id: 内容哈希 → 块 ID 的映射,用于前缀缓存查找。
|
||||
free_block_ids: 空闲块的 ID 队列。
|
||||
used_block_ids: 正在使用的块 ID 集合。
|
||||
"""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int):
|
||||
self.block_size = block_size
|
||||
@@ -34,6 +61,15 @@ class BlockManager:
|
||||
|
||||
@classmethod
|
||||
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
||||
"""计算一个块的内容哈希值。
|
||||
|
||||
哈希是链式的:每个块的哈希值依赖于前一个块的哈希值(prefix 参数),
|
||||
这确保了前缀缓存的一致性——只有完全相同的 token 序列前缀才会产生相同的哈希链。
|
||||
|
||||
Args:
|
||||
token_ids: 该块对应的 token ID 列表。
|
||||
prefix: 前一个块的哈希值,-1 表示第一个块(无前缀)。
|
||||
"""
|
||||
h = xxhash.xxh64()
|
||||
if prefix != -1:
|
||||
h.update(prefix.to_bytes(8, "little"))
|
||||
@@ -41,9 +77,11 @@ class BlockManager:
|
||||
return h.intdigest()
|
||||
|
||||
def _allocate_block(self) -> int:
|
||||
"""从空闲池中分配一个物理块。"""
|
||||
block_id = self.free_block_ids.popleft()
|
||||
block = self.blocks[block_id]
|
||||
assert block.ref_count == 0
|
||||
# 如果该块之前有哈希记录,先从哈希表中移除
|
||||
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
|
||||
del self.hash_to_block_id[block.hash]
|
||||
block.reset()
|
||||
@@ -51,11 +89,22 @@ class BlockManager:
|
||||
return block_id
|
||||
|
||||
def _deallocate_block(self, block_id: int):
|
||||
"""释放一个物理块回空闲池。"""
|
||||
assert self.blocks[block_id].ref_count == 0
|
||||
self.used_block_ids.remove(block_id)
|
||||
self.free_block_ids.append(block_id)
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> int:
|
||||
"""检查是否有足够的空闲块来分配给序列,同时计算前缀缓存命中数。
|
||||
|
||||
遍历序列的所有逻辑块,逐个计算哈希并与已有缓存比对:
|
||||
- 如果哈希匹配且 token 内容一致,则该块可以复用(前缀缓存命中)。
|
||||
- 一旦遇到不匹配的块就停止,因为前缀缓存要求连续匹配。
|
||||
- 最后检查剩余空闲块是否足够分配未命中部分。
|
||||
|
||||
Returns:
|
||||
前缀缓存命中的块数,如果没有足够的空闲块则返回 -1。
|
||||
"""
|
||||
h = -1
|
||||
num_cached_blocks = 0
|
||||
num_new_blocks = seq.num_blocks
|
||||
@@ -67,14 +116,25 @@ class BlockManager:
|
||||
break
|
||||
num_cached_blocks += 1
|
||||
if block_id in self.used_block_ids:
|
||||
# 块已在使用中(被其他序列共享),不需要额外分配
|
||||
num_new_blocks -= 1
|
||||
if len(self.free_block_ids) < num_new_blocks:
|
||||
return -1
|
||||
return num_cached_blocks
|
||||
|
||||
def allocate(self, seq: Sequence, num_cached_blocks: int):
|
||||
"""为序列分配 KV cache 物理块。
|
||||
|
||||
前缀缓存命中部分:增加已有块的引用计数。
|
||||
未命中部分:从空闲池分配新块。
|
||||
|
||||
Args:
|
||||
seq: 要分配块的序列。
|
||||
num_cached_blocks: 前缀缓存命中的块数(由 can_allocate 计算)。
|
||||
"""
|
||||
assert not seq.block_table
|
||||
h = -1
|
||||
# 处理缓存命中部分:复用已有块
|
||||
for i in range(num_cached_blocks):
|
||||
token_ids = seq.block(i)
|
||||
h = self.compute_hash(token_ids, h)
|
||||
@@ -83,15 +143,22 @@ class BlockManager:
|
||||
if block_id in self.used_block_ids:
|
||||
block.ref_count += 1
|
||||
else:
|
||||
# 块在空闲池中但哈希匹配(之前被释放但内容未覆盖),重新激活
|
||||
block.ref_count = 1
|
||||
self.free_block_ids.remove(block_id)
|
||||
self.used_block_ids.add(block_id)
|
||||
seq.block_table.append(block_id)
|
||||
# 分配新块用于未命中部分
|
||||
for i in range(num_cached_blocks, seq.num_blocks):
|
||||
seq.block_table.append(self._allocate_block())
|
||||
seq.num_cached_tokens = num_cached_blocks * self.block_size
|
||||
|
||||
def deallocate(self, seq: Sequence):
|
||||
"""释放序列占用的所有 KV cache 块。
|
||||
|
||||
递减每个块的引用计数,引用计数归零的块被回收进空闲池。
|
||||
注意:遍历顺序为逆序,这是为了优先释放最新分配的块(它们在空闲队列尾部)。
|
||||
"""
|
||||
for block_id in reversed(seq.block_table):
|
||||
block = self.blocks[block_id]
|
||||
block.ref_count -= 1
|
||||
@@ -101,16 +168,30 @@ class BlockManager:
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
"""检查 decode 阶段是否能追加一个 token。
|
||||
|
||||
当序列当前最后一个块已满时(len(seq) % block_size == 1,即新增 token 是新块的第一个),
|
||||
需要分配一个新块。否则只需要写入已有块,不需要额外分配。
|
||||
"""
|
||||
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
||||
|
||||
def may_append(self, seq: Sequence):
|
||||
"""在 decode 阶段,如果需要则追加一个新块。"""
|
||||
if len(seq) % self.block_size == 1:
|
||||
seq.block_table.append(self._allocate_block())
|
||||
|
||||
def hash_blocks(self, seq: Sequence):
|
||||
"""更新序列中已完成计算的块的哈希值。
|
||||
|
||||
在每次调度步骤完成后调用,将新计算完毕的块的 token 内容和哈希值
|
||||
注册到哈希表中,以便后续序列复用(前缀缓存)。
|
||||
|
||||
只处理从 num_cached_tokens 到当前进度的块(即本次新完成的块)。
|
||||
"""
|
||||
start = seq.num_cached_tokens // self.block_size
|
||||
end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
|
||||
if start == end: return
|
||||
# 继承前一个块的哈希值作为前缀
|
||||
h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
|
||||
for i in range(start, end):
|
||||
block = self.blocks[seq.block_table[i]]
|
||||
|
||||
@@ -13,40 +13,77 @@ from nanovllm.engine.model_runner import ModelRunner
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""LLM 推理引擎:整个 nano-vllm 的入口和协调者。
|
||||
|
||||
职责:
|
||||
1. 初始化配置、启动张量并行的子进程。
|
||||
2. 管理请求(添加/调度/生成)。
|
||||
3. 协调调度器和模型运行器之间的交互。
|
||||
|
||||
使用方式:
|
||||
>>> engine = LLMEngine("/path/to/model", enforce_eager=True)
|
||||
>>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64))
|
||||
"""
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
# 过滤出 Config 支持的参数,忽略其他参数
|
||||
config_fields = {field.name for field in fields(Config)}
|
||||
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
||||
config = Config(model, **config_kwargs)
|
||||
Sequence.block_size = config.kvcache_block_size
|
||||
|
||||
# 启动张量并行的工作进程(rank 1, 2, ...)
|
||||
# 使用 "spawn" 方式创建子进程,确保 CUDA context 独立
|
||||
self.ps = []
|
||||
self.events = []
|
||||
ctx = mp.get_context("spawn")
|
||||
for i in range(1, config.tensor_parallel_size):
|
||||
event = ctx.Event()
|
||||
event = ctx.Event() # 用于通知工作进程有新任务
|
||||
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
||||
process.start()
|
||||
self.ps.append(process)
|
||||
self.events.append(event)
|
||||
|
||||
# 在主进程中初始化模型运行器(rank 0)
|
||||
self.model_runner = ModelRunner(config, 0, self.events)
|
||||
# 加载 tokenizer(用于文本 ↔ token ID 转换)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
config.eos = self.tokenizer.eos_token_id
|
||||
# 初始化调度器(此时 num_kvcache_blocks 已由 ModelRunner 分配完毕)
|
||||
self.scheduler = Scheduler(config)
|
||||
# 注册退出清理函数
|
||||
atexit.register(self.exit)
|
||||
|
||||
def exit(self):
|
||||
"""通知所有进程退出并等待。"""
|
||||
self.model_runner.call("exit")
|
||||
del self.model_runner
|
||||
for p in self.ps:
|
||||
p.join()
|
||||
|
||||
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||||
"""添加一个生成请求到调度器的等待队列。
|
||||
|
||||
Args:
|
||||
prompt: 可以是字符串(会被 tokenizer 编码为 token ID)或已编码的 token ID 列表。
|
||||
sampling_params: 采样参数。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
seq = Sequence(prompt, sampling_params)
|
||||
self.scheduler.add(seq)
|
||||
|
||||
def step(self):
|
||||
"""执行一个推理步骤。
|
||||
|
||||
1. 调度器选择要处理的序列(prefill 或 decode)。
|
||||
2. 模型运行器执行前向推理和采样。
|
||||
3. 调度器后处理结果(更新缓存、检查终止条件)。
|
||||
|
||||
Returns:
|
||||
outputs: 已完成序列的 (seq_id, completion_token_ids) 列表。
|
||||
num_tokens: prefill 时为处理的 token 总数(正数),decode 时为序列数的负数。
|
||||
"""
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
@@ -55,6 +92,7 @@ class LLMEngine:
|
||||
return outputs, num_tokens
|
||||
|
||||
def is_finished(self):
|
||||
"""检查是否所有请求都已完成。"""
|
||||
return self.scheduler.is_finished()
|
||||
|
||||
def generate(
|
||||
@@ -63,6 +101,21 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
"""批量生成文本的入口方法。
|
||||
|
||||
工作流程:
|
||||
1. 将所有 prompt 添加为请求。
|
||||
2. 循环执行 step() 直到所有序列完成。
|
||||
3. 收集结果并解码为文本。
|
||||
|
||||
Args:
|
||||
prompts: prompt 列表,每个元素可以是字符串或 token ID 列表。
|
||||
sampling_params: 单个采样参数(应用于所有 prompt)或每个 prompt 对应的采样参数列表。
|
||||
use_tqdm: 是否显示进度条。
|
||||
|
||||
Returns:
|
||||
列表,每个元素是 {"text": 解码文本, "token_ids": token ID 列表}。
|
||||
"""
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
|
||||
if not isinstance(sampling_params, list):
|
||||
sampling_params = [sampling_params] * len(prompts)
|
||||
@@ -73,6 +126,7 @@ class LLMEngine:
|
||||
while not self.is_finished():
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
# num_tokens > 0 表示 prefill,< 0 表示 decode(取负得到序列数)
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
@@ -85,6 +139,7 @@ class LLMEngine:
|
||||
outputs[seq_id] = token_ids
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
# 按 seq_id 排序输出(保证与输入顺序一致)
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||
return outputs
|
||||
|
||||
+111
-10
@@ -13,6 +13,18 @@ from nanovllm.utils.loader import load_model
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
"""模型运行器:负责模型推理、KV cache 管理、CUDA Graph 捕获和张量并行通信。
|
||||
|
||||
在张量并行(TP)模式下:
|
||||
- Rank 0 是主进程,负责采样和与引擎通信。
|
||||
- Rank > 0 是工作进程,通过共享内存(SharedMemory)接收指令。
|
||||
- 所有进程共享同一个模型和 KV cache 的分片。
|
||||
|
||||
生命周期:
|
||||
1. 初始化: 加载模型 → warmup → 分配 KV cache → (可选)捕获 CUDA Graph
|
||||
2. 推理: 接收序列 → 准备输入 → 运行模型 → 采样 token
|
||||
3. 退出: 释放资源
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
|
||||
self.config = config
|
||||
@@ -23,42 +35,54 @@ class ModelRunner:
|
||||
self.rank = rank
|
||||
self.event = event
|
||||
|
||||
# 初始化分布式进程组(NCCL 后端),所有 GPU 通过 TCP 通信
|
||||
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
# 加载模型权重
|
||||
default_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(hf_config.dtype)
|
||||
torch.set_default_device("cuda")
|
||||
self.model = Qwen3ForCausalLM(hf_config)
|
||||
load_model(self.model, config.model)
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Warmup: 运行一次前向传播以确定模型本身的显存占用
|
||||
self.warmup_model()
|
||||
# 根据剩余显存分配 KV cache
|
||||
self.allocate_kv_cache()
|
||||
# 捕获 CUDA Graph 以加速 decode 阶段的小批量推理
|
||||
if not self.enforce_eager:
|
||||
self.capture_cudagraph()
|
||||
|
||||
torch.set_default_device("cpu")
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
# 张量并行时,rank > 0 的工作进程进入消息循环
|
||||
if self.world_size > 1:
|
||||
if rank == 0:
|
||||
# 主进程创建共享内存,工作进程打开它
|
||||
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
|
||||
dist.barrier()
|
||||
else:
|
||||
dist.barrier()
|
||||
self.shm = SharedMemory(name="nanovllm")
|
||||
self.loop()
|
||||
self.loop() # 工作进程在此循环,直到收到 exit 指令
|
||||
|
||||
def exit(self):
|
||||
"""释放所有资源并退出。"""
|
||||
if self.world_size > 1:
|
||||
self.shm.close()
|
||||
dist.barrier()
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
self.shm.unlink() # 只有创建者需要 unlink
|
||||
if not self.enforce_eager:
|
||||
del self.graphs, self.graph_pool
|
||||
torch.cuda.synchronize()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def loop(self):
|
||||
"""工作进程的主循环:等待主进程指令,执行对应方法。"""
|
||||
while True:
|
||||
method_name, args = self.read_shm()
|
||||
self.call(method_name, *args)
|
||||
@@ -66,14 +90,16 @@ class ModelRunner:
|
||||
break
|
||||
|
||||
def read_shm(self):
|
||||
"""从共享内存读取主进程发送的方法调用指令。"""
|
||||
assert self.world_size > 1 and self.rank > 0
|
||||
self.event.wait()
|
||||
self.event.wait() # 等待主进程通知
|
||||
n = int.from_bytes(self.shm.buf[0:4], "little")
|
||||
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
|
||||
self.event.clear()
|
||||
return method_name, args
|
||||
|
||||
def write_shm(self, method_name, *args):
|
||||
"""将方法调用指令写入共享内存,通知工作进程。"""
|
||||
assert self.world_size > 1 and self.rank == 0
|
||||
data = pickle.dumps([method_name, *args])
|
||||
n = len(data)
|
||||
@@ -83,12 +109,19 @@ class ModelRunner:
|
||||
event.set()
|
||||
|
||||
def call(self, method_name, *args):
|
||||
"""调用指定方法。TP 模式下主进程先通知工作进程,再本地执行。"""
|
||||
if self.world_size > 1 and self.rank == 0:
|
||||
self.write_shm(method_name, *args)
|
||||
method = getattr(self, method_name, None)
|
||||
return method(*args)
|
||||
|
||||
def warmup_model(self):
|
||||
"""预热模型:运行一次最大批量的前向传播。
|
||||
|
||||
目的是让 PyTorch 分配所有内部缓存(cuBLAS workspace 等),
|
||||
然后通过 empty_cache 释放临时显存,这样后续的 peak memory 统计
|
||||
就只包含模型权重,从而准确计算 KV cache 可用空间。
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
|
||||
@@ -101,6 +134,15 @@ class ModelRunner:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def allocate_kv_cache(self):
|
||||
"""根据剩余 GPU 显存分配 KV cache。
|
||||
|
||||
计算公式:
|
||||
可用显存 = 总显存 × gpu_memory_utilization - 非模型占用
|
||||
其中非模型占用 = 已用显存 - peak(模型权重)+ current(当前模型张量)
|
||||
|
||||
KV cache 形状: (2, num_layers, num_blocks, block_size, num_kv_heads/head_dim)
|
||||
其中第一维 2 分别对应 K 和 V cache。
|
||||
"""
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
@@ -109,10 +151,13 @@ class ModelRunner:
|
||||
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
# 每个块占用的字节数:2(K+V) × 层数 × block_size × KV头数 × head_dim × dtype字节数
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
|
||||
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert config.num_kvcache_blocks > 0
|
||||
# 分配 KV cache 张量,形状为 (2, num_layers, num_blocks, block_size, num_kv_heads, head_dim)
|
||||
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
||||
# 将 KV cache 的视图绑定到模型中每个 Attention 层
|
||||
layer_id = 0
|
||||
for module in self.model.modules():
|
||||
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
||||
@@ -121,12 +166,26 @@ class ModelRunner:
|
||||
layer_id += 1
|
||||
|
||||
def prepare_block_tables(self, seqs: list[Sequence]):
|
||||
"""将序列的 block_table 列表填充为等长的二维张量,用于 GPU 计算。"""
|
||||
max_len = max(len(seq.block_table) for seq in seqs)
|
||||
# 用 -1 填充短序列的 block_table
|
||||
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
return block_tables
|
||||
|
||||
def prepare_prefill(self, seqs: list[Sequence]):
|
||||
"""准备 prefill 阶段的模型输入张量。
|
||||
|
||||
Prefill 阶段需要处理多个序列的 prompt tokens,所有 token 被拼接成一个连续的输入。
|
||||
使用 cu_seqlens(累积序列长度)来标记每个序列的边界,供 flash_attn_varlen 使用。
|
||||
|
||||
关键数据:
|
||||
- input_ids: 所有序列的 token ID 拼接。
|
||||
- positions: 每个 token 的位置 ID(考虑前缀缓存偏移)。
|
||||
- cu_seqlens_q/k: 查询和键值的累积序列长度。
|
||||
- slot_mapping: 将每个 token 映射到 KV cache 中的物理存储位置。
|
||||
- block_tables: 前缀缓存命中时需要 block_table 来从 KV cache 读取已缓存的 K/V。
|
||||
"""
|
||||
input_ids = []
|
||||
positions = []
|
||||
cu_seqlens_q = [0]
|
||||
@@ -136,18 +195,19 @@ class ModelRunner:
|
||||
slot_mapping = []
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
start = seq.num_cached_tokens
|
||||
start = seq.num_cached_tokens # 跳过已缓存的 token
|
||||
seqlen_q = seq.num_scheduled_tokens
|
||||
end = start + seqlen_q
|
||||
seqlen_k = end
|
||||
seqlen_k = end # KV 的长度是从 0 到 end(包括缓存前缀)
|
||||
input_ids.extend(seq[start:end])
|
||||
positions.extend(range(start, end))
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
if not seq.block_table: # warmup
|
||||
if not seq.block_table: # warmup 阶段没有 block_table
|
||||
continue
|
||||
# 计算 slot_mapping:每个 token 对应 KV cache 中的哪个 slot
|
||||
start_block = start // self.block_size
|
||||
end_block = (end + self.block_size - 1) // self.block_size
|
||||
for i in range(start_block, end_block):
|
||||
@@ -159,7 +219,7 @@ class ModelRunner:
|
||||
else:
|
||||
slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
|
||||
slot_mapping.extend(range(slot_start, slot_end))
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # 前缀缓存命中时,KV 长度 > Q 长度
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
@@ -170,6 +230,18 @@ class ModelRunner:
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
"""准备 decode 阶段的模型输入张量。
|
||||
|
||||
Decode 阶段每个序列只处理 1 个 token(最新生成的 token)。
|
||||
模型从 KV cache 中读取之前所有的 K/V 向量来做注意力计算。
|
||||
|
||||
关键数据:
|
||||
- input_ids: 每个序列的最新 token ID。
|
||||
- positions: 每个 token 的位置 ID(序列长度 - 1)。
|
||||
- slot_mapping: 新 token 的 KV 写入位置。
|
||||
- context_lens: 每个序列的上下文总长度。
|
||||
- block_tables: KV cache 块映射表。
|
||||
"""
|
||||
input_ids = []
|
||||
positions = []
|
||||
slot_mapping = []
|
||||
@@ -178,6 +250,7 @@ class ModelRunner:
|
||||
input_ids.append(seq.last_token)
|
||||
positions.append(len(seq) - 1)
|
||||
context_lens.append(len(seq))
|
||||
# slot = 最后一个块的起始位置 + 该块内已有 token 数 - 1
|
||||
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
@@ -188,19 +261,30 @@ class ModelRunner:
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_sample(self, seqs: list[Sequence]):
|
||||
"""准备采样所需的温度参数张量。"""
|
||||
temperatures = [seq.temperature for seq in seqs]
|
||||
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
||||
return temperatures
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
|
||||
"""运行模型前向传播。
|
||||
|
||||
对于 decode 阶段的小批量(<=512),使用 CUDA Graph 加速:
|
||||
CUDA Graph 将整个计算图"录制"下来,后续只需回放即可,避免了
|
||||
CPU 端的 kernel launch 开销,对 decode(每个 step 计算量很小)尤为有效。
|
||||
"""
|
||||
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
|
||||
# 直接运行:prefill(批量动态)、eager 模式、或大批量 decode
|
||||
return self.model.compute_logits(self.model(input_ids, positions))
|
||||
else:
|
||||
# 使用 CUDA Graph 回放加速小批量 decode
|
||||
bs = input_ids.size(0)
|
||||
context = get_context()
|
||||
# 选择 >= bs 的最小预捕获图大小
|
||||
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||
graph_vars = self.graph_vars
|
||||
# 将实际输入拷贝到图预分配的固定大小缓冲区中
|
||||
graph_vars["input_ids"][:bs] = input_ids
|
||||
graph_vars["positions"][:bs] = positions
|
||||
graph_vars["slot_mapping"].fill_(-1)
|
||||
@@ -208,10 +292,16 @@ class ModelRunner:
|
||||
graph_vars["context_lens"].zero_()
|
||||
graph_vars["context_lens"][:bs] = context.context_lens
|
||||
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
||||
# 回放图(比重新执行快,跳过了 Python/PyTorch 调度开销)
|
||||
graph.replay()
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
"""执行一次完整的推理步骤:准备输入 → 模型前向 → 采样。
|
||||
|
||||
Returns:
|
||||
采样得到的 token ID 列表(仅 rank 0 返回有效值)。
|
||||
"""
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
@@ -221,28 +311,39 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_cudagraph(self):
|
||||
"""预捕获不同批量大小的 CUDA Graph。
|
||||
|
||||
CUDA Graph 要求输入张量的地址不变(同一个内存池),所以需要预分配
|
||||
固定大小的输入缓冲区,并为每个 batch size 录制一个图。
|
||||
|
||||
预捕获的 batch size: [1, 2, 4, 8, 16, 32, ..., max_bs]
|
||||
运行时选择 >= 实际 batch size 的最小预捕获图。
|
||||
"""
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
max_bs = min(self.config.max_num_seqs, 512)
|
||||
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
||||
# 预分配固定地址的输入/输出缓冲区
|
||||
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
||||
positions = torch.zeros(max_bs, dtype=torch.int64)
|
||||
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
||||
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||
# 要捕获的 batch size 列表
|
||||
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
|
||||
self.graphs = {}
|
||||
self.graph_pool = None
|
||||
|
||||
# 逆序捕获:先捕获大的 batch size,共享同一个 graph pool
|
||||
for bs in reversed(self.graph_bs):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup 运行
|
||||
with torch.cuda.graph(graph, self.graph_pool):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # 捕获计算图
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = graph.pool()
|
||||
self.graph_pool = graph.pool() # 所有图共享同一个内存池
|
||||
self.graphs[bs] = graph
|
||||
torch.cuda.synchronize()
|
||||
reset_context()
|
||||
|
||||
@@ -6,6 +6,22 @@ from nanovllm.engine.block_manager import BlockManager
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""调度器:决定每个步骤(step)中哪些序列被处理以及处理多少 token。
|
||||
|
||||
调度策略采用 vLLM 风格的 prefill-decode 分离调度:
|
||||
1. **Prefill 优先**: 每次调度先尝试处理等待中的序列(计算其 prompt 的 KV cache)。
|
||||
2. **Chunked prefill**: 如果一个序列的 prompt 太长,可以分多次调度处理。
|
||||
3. **Decode**: 当没有等待中的序列时(或 prefill token 额度用完),处理正在解码的序列。
|
||||
4. **抢占(Preemption)**: 当 KV cache 空间不足时,将最近的 running 序列抢占回 waiting 队列。
|
||||
|
||||
调度约束:
|
||||
- 总 token 数不超过 max_num_batched_tokens(prefill 阶段)。
|
||||
- 总序列数不超过 max_num_seqs。
|
||||
|
||||
Attributes:
|
||||
waiting: 等待处理的序列队列(FIFO)。
|
||||
running: 正在解码的序列队列。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.max_num_seqs = config.max_num_seqs
|
||||
@@ -17,50 +33,79 @@ class Scheduler:
|
||||
self.running: deque[Sequence] = deque()
|
||||
|
||||
def is_finished(self):
|
||||
"""检查是否所有序列都已完成。"""
|
||||
return not self.waiting and not self.running
|
||||
|
||||
def add(self, seq: Sequence):
|
||||
"""将一个新序列加入等待队列。"""
|
||||
self.waiting.append(seq)
|
||||
|
||||
def schedule(self) -> tuple[list[Sequence], bool]:
|
||||
"""执行一次调度,返回 (被调度的序列列表, 是否为 prefill 阶段)。
|
||||
|
||||
调度逻辑:
|
||||
1. Prefill 阶段:从 waiting 队列中选取序列,检查前缀缓存命中情况,
|
||||
为每个序列计算需要处理的 token 数量。支持 chunked prefill(长 prompt 分多次处理)。
|
||||
2. Decode 阶段:从 running 队列中选取序列,每个序列处理 1 个 token。
|
||||
如果 KV cache 空间不足,会抢占(preempt)最近加入 running 的序列。
|
||||
"""
|
||||
scheduled_seqs = []
|
||||
num_batched_tokens = 0
|
||||
|
||||
# prefill
|
||||
# ========== Prefill 阶段 ==========
|
||||
# 尝试从 waiting 队列中调度序列,计算它们的 prompt KV cache
|
||||
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
|
||||
seq = self.waiting[0]
|
||||
remaining = self.max_num_batched_tokens - num_batched_tokens
|
||||
if remaining == 0:
|
||||
break
|
||||
|
||||
if not seq.block_table:
|
||||
# 序列尚未分配块,检查前缀缓存和空闲块
|
||||
num_cached_blocks = self.block_manager.can_allocate(seq)
|
||||
if num_cached_blocks == -1:
|
||||
# 空闲块不足,停止调度
|
||||
break
|
||||
# 需要实际处理的 token 数 = 总 prompt token 数 - 缓存命中的 token 数
|
||||
num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
|
||||
else:
|
||||
# 序列已经有块(chunked prefill 的后续分片),只需处理未缓存的 token
|
||||
num_tokens = seq.num_tokens - seq.num_cached_tokens
|
||||
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
|
||||
|
||||
if remaining < num_tokens and scheduled_seqs:
|
||||
# token 预算不足以处理整个序列,且已有其他序列被调度
|
||||
# 注意:第一个序列允许 chunked prefill(remaining < num_tokens 也可以)
|
||||
break
|
||||
|
||||
if not seq.block_table:
|
||||
self.block_manager.allocate(seq, num_cached_blocks)
|
||||
|
||||
# 实际调度的 token 数取 min(num_tokens, remaining),实现 chunked prefill
|
||||
seq.num_scheduled_tokens = min(num_tokens, remaining)
|
||||
num_batched_tokens += seq.num_scheduled_tokens
|
||||
|
||||
if seq.num_cached_tokens + seq.num_scheduled_tokens == seq.num_tokens:
|
||||
# 整个 prompt 已全部处理完毕,转移到 running 队列
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
self.waiting.popleft()
|
||||
self.running.append(seq)
|
||||
|
||||
scheduled_seqs.append(seq)
|
||||
|
||||
if scheduled_seqs:
|
||||
return scheduled_seqs, True
|
||||
return scheduled_seqs, True # is_prefill = True
|
||||
|
||||
# decode
|
||||
# ========== Decode 阶段 ==========
|
||||
# 逐 token 解码,每个序列每次生成 1 个 token
|
||||
while self.running and len(scheduled_seqs) < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
# 检查是否有空闲块用于存储新的 KV cache
|
||||
while not self.block_manager.can_append(seq):
|
||||
if self.running:
|
||||
# 空间不足,抢占最近加入 running 的序列
|
||||
self.preempt(self.running.pop())
|
||||
else:
|
||||
# 连当前序列都要被抢占
|
||||
self.preempt(seq)
|
||||
break
|
||||
else:
|
||||
@@ -68,21 +113,37 @@ class Scheduler:
|
||||
seq.is_prefill = False
|
||||
self.block_manager.may_append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
assert scheduled_seqs
|
||||
|
||||
assert scheduled_seqs, "No sequences to schedule"
|
||||
# 将调度过的序列放回 running 队列前端(保持顺序)
|
||||
self.running.extendleft(reversed(scheduled_seqs))
|
||||
return scheduled_seqs, False
|
||||
return scheduled_seqs, False # is_prefill = False
|
||||
|
||||
def preempt(self, seq: Sequence):
|
||||
"""抢占一个序列:释放其 KV cache 并放回等待队列头部。
|
||||
|
||||
抢占后序列需要重新做 prefill(重新计算 KV cache)。
|
||||
这是一种牺牲吞吐量来换取 KV cache 空间的策略。
|
||||
"""
|
||||
seq.status = SequenceStatus.WAITING
|
||||
seq.is_prefill = True
|
||||
self.block_manager.deallocate(seq)
|
||||
self.waiting.appendleft(seq)
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
|
||||
"""在模型推理完成后处理每个序列的结果。
|
||||
|
||||
1. 更新块的哈希值(用于前缀缓存)。
|
||||
2. 更新已缓存的 token 计数。
|
||||
3. 对于 prefill:如果整个 prompt 还没处理完,继续等待下一次调度。
|
||||
4. 对于完成的序列(prefill 结束后或 decode 中):追加生成的 token。
|
||||
5. 检查终止条件(EOS 或达到 max_tokens),完成的序列被释放资源。
|
||||
"""
|
||||
for seq, token_id in zip(seqs, token_ids):
|
||||
self.block_manager.hash_blocks(seq)
|
||||
seq.num_cached_tokens += seq.num_scheduled_tokens
|
||||
seq.num_scheduled_tokens = 0
|
||||
# prefill 阶段如果还没处理完整个 prompt,不采样 token,继续等待
|
||||
if is_prefill and seq.num_cached_tokens < seq.num_tokens:
|
||||
continue
|
||||
seq.append_token(token_id)
|
||||
|
||||
@@ -6,12 +6,35 @@ from nanovllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(Enum):
|
||||
"""序列的生命周期状态。
|
||||
|
||||
WAITING: 等待被调度器选中进入 prefill 阶段。
|
||||
RUNNING: 正在解码(decode)阶段,逐 token 生成。
|
||||
FINISHED: 生成完成(遇到 EOS 或达到 max_tokens)。
|
||||
"""
|
||||
WAITING = auto()
|
||||
RUNNING = auto()
|
||||
FINISHED = auto()
|
||||
|
||||
|
||||
class Sequence:
|
||||
"""表示一个推理请求的序列。
|
||||
|
||||
每个 Sequence 封装了一个完整的生成请求,从 prompt tokens 到生成的 completion tokens。
|
||||
它是调度器、块管理器和模型运行器之间传递的核心数据结构。
|
||||
|
||||
关键概念:
|
||||
- prompt tokens: 用户输入的原始 token 序列。
|
||||
- completion tokens: 模型生成的 token 序列。
|
||||
- cached tokens: 已经计算过 KV cache 的 token 数量(用于前缀缓存和 chunked prefill)。
|
||||
- scheduled tokens: 当前调度步骤中计划处理的 token 数量。
|
||||
- block_table: 该序列在 KV cache 中占用的物理块 ID 列表。
|
||||
|
||||
类属性:
|
||||
block_size: KV cache 块大小,由 Config 在引擎初始化时设置。
|
||||
counter: 全局自增计数器,用于为每个序列分配唯一 ID。
|
||||
"""
|
||||
|
||||
block_size = 256
|
||||
counter = count()
|
||||
|
||||
@@ -20,12 +43,12 @@ class Sequence:
|
||||
self.status = SequenceStatus.WAITING
|
||||
self.token_ids = copy(token_ids)
|
||||
self.last_token = token_ids[-1]
|
||||
self.num_tokens = len(self.token_ids)
|
||||
self.num_prompt_tokens = len(token_ids)
|
||||
self.num_cached_tokens = 0
|
||||
self.num_scheduled_tokens = 0
|
||||
self.is_prefill = True
|
||||
self.block_table = []
|
||||
self.num_tokens = len(self.token_ids) # 当前总 token 数(prompt + 已生成的)
|
||||
self.num_prompt_tokens = len(token_ids) # prompt 部分的 token 数,生成过程中不变
|
||||
self.num_cached_tokens = 0 # 已计算 KV cache 的 token 数,用于前缀缓存命中判断
|
||||
self.num_scheduled_tokens = 0 # 当前步骤中被调度处理的 token 数
|
||||
self.is_prefill = True # 是否处于 prefill 阶段(首次计算 prompt 的 KV cache)
|
||||
self.block_table = [] # KV cache 物理块 ID 列表,索引为逻辑块号
|
||||
self.temperature = sampling_params.temperature
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
self.ignore_eos = sampling_params.ignore_eos
|
||||
@@ -42,42 +65,56 @@ class Sequence:
|
||||
|
||||
@property
|
||||
def num_completion_tokens(self):
|
||||
"""已生成的 completion token 数量。"""
|
||||
return self.num_tokens - self.num_prompt_tokens
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self):
|
||||
"""返回 prompt 部分的 token ID 列表。"""
|
||||
return self.token_ids[:self.num_prompt_tokens]
|
||||
|
||||
@property
|
||||
def completion_token_ids(self):
|
||||
"""返回已生成的 completion token ID 列表。"""
|
||||
return self.token_ids[self.num_prompt_tokens:]
|
||||
|
||||
@property
|
||||
def num_blocks(self):
|
||||
"""该序列需要的 KV cache 逻辑块数量(向上取整)。"""
|
||||
return (self.num_tokens + self.block_size - 1) // self.block_size
|
||||
|
||||
@property
|
||||
def last_block_num_tokens(self):
|
||||
"""最后一个块中已使用的 token 数量。"""
|
||||
return self.num_tokens - (self.num_blocks - 1) * self.block_size
|
||||
|
||||
def block(self, i):
|
||||
"""获取第 i 个逻辑块对应的 token ID 列表。"""
|
||||
assert 0 <= i < self.num_blocks
|
||||
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
||||
|
||||
def append_token(self, token_id: int):
|
||||
"""将一个新生成的 token 追加到序列末尾。"""
|
||||
self.token_ids.append(token_id)
|
||||
self.last_token = token_id
|
||||
self.num_tokens += 1
|
||||
|
||||
def __getstate__(self):
|
||||
"""序列化时只保存必要的状态,用于多进程间传递序列数据。
|
||||
|
||||
prefill 阶段保存完整 token_ids(模型需要读取全部 prompt tokens),
|
||||
decode 阶段只保存 last_token(模型只需要最新一个 token 的 ID)。
|
||||
"""
|
||||
last_state = self.last_token if not self.is_prefill else self.token_ids
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""反序列化,恢复序列状态。"""
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state
|
||||
if isinstance(last_state, list):
|
||||
self.token_ids = last_state
|
||||
self.last_token = self.token_ids[-1]
|
||||
else:
|
||||
# decode 阶段不需要完整的 token_ids,只保存了 last_token
|
||||
self.token_ids = []
|
||||
self.last_token = last_state
|
||||
|
||||
Reference in New Issue
Block a user