ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
202 lines
8.4 KiB
Python
202 lines
8.4 KiB
Python
from collections import deque
|
|
import xxhash
|
|
import numpy as np
|
|
|
|
from nanovllm.engine.sequence import Sequence
|
|
|
|
|
|
class Block:
|
|
"""KV cache 的物理块。
|
|
|
|
每个 Block 对应 GPU 显存中一块固定大小的 KV cache 存储,可以存放 block_size 个 token 的 KV 向量。
|
|
|
|
Attributes:
|
|
block_id: 物理块的唯一 ID。
|
|
ref_count: 引用计数,有多少个序列正在共享这个块(前缀缓存复用)。
|
|
hash: 该块内容的哈希值,用于前缀缓存查找。初始为 -1 表示未计算哈希。
|
|
token_ids: 该块对应的 token ID 列表,用于验证缓存命中时内容是否一致。
|
|
"""
|
|
|
|
def __init__(self, block_id):
|
|
self.block_id = block_id
|
|
self.ref_count = 0
|
|
self.hash = -1
|
|
self.token_ids = []
|
|
|
|
def update(self, hash: int, token_ids: list[int]):
|
|
"""更新块的内容哈希和 token_ids。"""
|
|
self.hash = hash
|
|
self.token_ids = token_ids
|
|
|
|
def reset(self):
|
|
"""重置块状态,准备被重新分配。"""
|
|
self.ref_count = 1
|
|
self.hash = -1
|
|
self.token_ids = []
|
|
|
|
|
|
class BlockManager:
|
|
"""管理 KV cache 物理块的分配、释放和前缀缓存。
|
|
|
|
核心设计:
|
|
- 将 KV cache 分成固定大小的物理块(block_size 个 token/块)。
|
|
- 使用哈希表实现前缀缓存(prefix caching):不同序列如果拥有相同的 prompt 前缀,
|
|
可以共享同一组 KV cache 块,避免重复计算。
|
|
- 块通过引用计数(ref_count)管理生命周期,所有引用释放后块变为空闲可复用。
|
|
|
|
Attributes:
|
|
block_size: 每个块存储的 token 数量。
|
|
blocks: 所有物理块的列表,索引即 block_id。
|
|
hash_to_block_id: 内容哈希 → 块 ID 的映射,用于前缀缓存查找。
|
|
free_block_ids: 空闲块的 ID 队列。
|
|
used_block_ids: 正在使用的块 ID 集合。
|
|
"""
|
|
|
|
def __init__(self, num_blocks: int, block_size: int):
|
|
self.block_size = block_size
|
|
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
|
self.hash_to_block_id: dict[int, int] = dict()
|
|
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
|
self.used_block_ids: set[int] = set()
|
|
|
|
@classmethod
|
|
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
|
"""计算一个块的内容哈希值。
|
|
|
|
哈希是链式的:每个块的哈希值依赖于前一个块的哈希值(prefix 参数),
|
|
这确保了前缀缓存的一致性——只有完全相同的 token 序列前缀才会产生相同的哈希链。
|
|
|
|
Args:
|
|
token_ids: 该块对应的 token ID 列表。
|
|
prefix: 前一个块的哈希值,-1 表示第一个块(无前缀)。
|
|
"""
|
|
h = xxhash.xxh64()
|
|
if prefix != -1:
|
|
h.update(prefix.to_bytes(8, "little"))
|
|
h.update(np.array(token_ids).tobytes())
|
|
return h.intdigest()
|
|
|
|
def _allocate_block(self) -> int:
|
|
"""从空闲池中分配一个物理块。"""
|
|
block_id = self.free_block_ids.popleft()
|
|
block = self.blocks[block_id]
|
|
assert block.ref_count == 0
|
|
# 如果该块之前有哈希记录,先从哈希表中移除
|
|
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
|
|
del self.hash_to_block_id[block.hash]
|
|
block.reset()
|
|
self.used_block_ids.add(block_id)
|
|
return block_id
|
|
|
|
def _deallocate_block(self, block_id: int):
|
|
"""释放一个物理块回空闲池。"""
|
|
assert self.blocks[block_id].ref_count == 0
|
|
self.used_block_ids.remove(block_id)
|
|
self.free_block_ids.append(block_id)
|
|
|
|
def can_allocate(self, seq: Sequence) -> int:
|
|
"""检查是否有足够的空闲块来分配给序列,同时计算前缀缓存命中数。
|
|
|
|
遍历序列的所有逻辑块,逐个计算哈希并与已有缓存比对:
|
|
- 如果哈希匹配且 token 内容一致,则该块可以复用(前缀缓存命中)。
|
|
- 一旦遇到不匹配的块就停止,因为前缀缓存要求连续匹配。
|
|
- 最后检查剩余空闲块是否足够分配未命中部分。
|
|
|
|
Returns:
|
|
前缀缓存命中的块数,如果没有足够的空闲块则返回 -1。
|
|
"""
|
|
h = -1
|
|
num_cached_blocks = 0
|
|
num_new_blocks = seq.num_blocks
|
|
for i in range(seq.num_blocks - 1):
|
|
token_ids = seq.block(i)
|
|
h = self.compute_hash(token_ids, h)
|
|
block_id = self.hash_to_block_id.get(h, -1)
|
|
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
|
break
|
|
num_cached_blocks += 1
|
|
if block_id in self.used_block_ids:
|
|
# 块已在使用中(被其他序列共享),不需要额外分配
|
|
num_new_blocks -= 1
|
|
if len(self.free_block_ids) < num_new_blocks:
|
|
return -1
|
|
return num_cached_blocks
|
|
|
|
def allocate(self, seq: Sequence, num_cached_blocks: int):
|
|
"""为序列分配 KV cache 物理块。
|
|
|
|
前缀缓存命中部分:增加已有块的引用计数。
|
|
未命中部分:从空闲池分配新块。
|
|
|
|
Args:
|
|
seq: 要分配块的序列。
|
|
num_cached_blocks: 前缀缓存命中的块数(由 can_allocate 计算)。
|
|
"""
|
|
assert not seq.block_table
|
|
h = -1
|
|
# 处理缓存命中部分:复用已有块
|
|
for i in range(num_cached_blocks):
|
|
token_ids = seq.block(i)
|
|
h = self.compute_hash(token_ids, h)
|
|
block_id = self.hash_to_block_id[h]
|
|
block = self.blocks[block_id]
|
|
if block_id in self.used_block_ids:
|
|
block.ref_count += 1
|
|
else:
|
|
# 块在空闲池中但哈希匹配(之前被释放但内容未覆盖),重新激活
|
|
block.ref_count = 1
|
|
self.free_block_ids.remove(block_id)
|
|
self.used_block_ids.add(block_id)
|
|
seq.block_table.append(block_id)
|
|
# 分配新块用于未命中部分
|
|
for i in range(num_cached_blocks, seq.num_blocks):
|
|
seq.block_table.append(self._allocate_block())
|
|
seq.num_cached_tokens = num_cached_blocks * self.block_size
|
|
|
|
def deallocate(self, seq: Sequence):
|
|
"""释放序列占用的所有 KV cache 块。
|
|
|
|
递减每个块的引用计数,引用计数归零的块被回收进空闲池。
|
|
注意:遍历顺序为逆序,这是为了优先释放最新分配的块(它们在空闲队列尾部)。
|
|
"""
|
|
for block_id in reversed(seq.block_table):
|
|
block = self.blocks[block_id]
|
|
block.ref_count -= 1
|
|
if block.ref_count == 0:
|
|
self._deallocate_block(block_id)
|
|
seq.num_cached_tokens = 0
|
|
seq.block_table.clear()
|
|
|
|
def can_append(self, seq: Sequence) -> bool:
|
|
"""检查 decode 阶段是否能追加一个 token。
|
|
|
|
当序列当前最后一个块已满时(len(seq) % block_size == 1,即新增 token 是新块的第一个),
|
|
需要分配一个新块。否则只需要写入已有块,不需要额外分配。
|
|
"""
|
|
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
|
|
|
def may_append(self, seq: Sequence):
|
|
"""在 decode 阶段,如果需要则追加一个新块。"""
|
|
if len(seq) % self.block_size == 1:
|
|
seq.block_table.append(self._allocate_block())
|
|
|
|
def hash_blocks(self, seq: Sequence):
|
|
"""更新序列中已完成计算的块的哈希值。
|
|
|
|
在每次调度步骤完成后调用,将新计算完毕的块的 token 内容和哈希值
|
|
注册到哈希表中,以便后续序列复用(前缀缓存)。
|
|
|
|
只处理从 num_cached_tokens 到当前进度的块(即本次新完成的块)。
|
|
"""
|
|
start = seq.num_cached_tokens // self.block_size
|
|
end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
|
|
if start == end: return
|
|
# 继承前一个块的哈希值作为前缀
|
|
h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
|
|
for i in range(start, end):
|
|
block = self.blocks[seq.block_table[i]]
|
|
token_ids = seq.block(i)
|
|
h = self.compute_hash(token_ids, h)
|
|
block.update(h, token_ids)
|
|
self.hash_to_block_id[h] = block.block_id
|