ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
121 lines
4.8 KiB
Python
121 lines
4.8 KiB
Python
from copy import copy
|
||
from enum import Enum, auto
|
||
from itertools import count
|
||
|
||
from nanovllm.sampling_params import SamplingParams
|
||
|
||
|
||
class SequenceStatus(Enum):
|
||
"""序列的生命周期状态。
|
||
|
||
WAITING: 等待被调度器选中进入 prefill 阶段。
|
||
RUNNING: 正在解码(decode)阶段,逐 token 生成。
|
||
FINISHED: 生成完成(遇到 EOS 或达到 max_tokens)。
|
||
"""
|
||
WAITING = auto()
|
||
RUNNING = auto()
|
||
FINISHED = auto()
|
||
|
||
|
||
class Sequence:
|
||
"""表示一个推理请求的序列。
|
||
|
||
每个 Sequence 封装了一个完整的生成请求,从 prompt tokens 到生成的 completion tokens。
|
||
它是调度器、块管理器和模型运行器之间传递的核心数据结构。
|
||
|
||
关键概念:
|
||
- prompt tokens: 用户输入的原始 token 序列。
|
||
- completion tokens: 模型生成的 token 序列。
|
||
- cached tokens: 已经计算过 KV cache 的 token 数量(用于前缀缓存和 chunked prefill)。
|
||
- scheduled tokens: 当前调度步骤中计划处理的 token 数量。
|
||
- block_table: 该序列在 KV cache 中占用的物理块 ID 列表。
|
||
|
||
类属性:
|
||
block_size: KV cache 块大小,由 Config 在引擎初始化时设置。
|
||
counter: 全局自增计数器,用于为每个序列分配唯一 ID。
|
||
"""
|
||
|
||
block_size = 256
|
||
counter = count()
|
||
|
||
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
|
||
self.seq_id = next(Sequence.counter)
|
||
self.status = SequenceStatus.WAITING
|
||
self.token_ids = copy(token_ids)
|
||
self.last_token = token_ids[-1]
|
||
self.num_tokens = len(self.token_ids) # 当前总 token 数(prompt + 已生成的)
|
||
self.num_prompt_tokens = len(token_ids) # prompt 部分的 token 数,生成过程中不变
|
||
self.num_cached_tokens = 0 # 已计算 KV cache 的 token 数,用于前缀缓存命中判断
|
||
self.num_scheduled_tokens = 0 # 当前步骤中被调度处理的 token 数
|
||
self.is_prefill = True # 是否处于 prefill 阶段(首次计算 prompt 的 KV cache)
|
||
self.block_table = [] # KV cache 物理块 ID 列表,索引为逻辑块号
|
||
self.temperature = sampling_params.temperature
|
||
self.max_tokens = sampling_params.max_tokens
|
||
self.ignore_eos = sampling_params.ignore_eos
|
||
|
||
def __len__(self):
|
||
return self.num_tokens
|
||
|
||
def __getitem__(self, key):
|
||
return self.token_ids[key]
|
||
|
||
@property
|
||
def is_finished(self):
|
||
return self.status == SequenceStatus.FINISHED
|
||
|
||
@property
|
||
def num_completion_tokens(self):
|
||
"""已生成的 completion token 数量。"""
|
||
return self.num_tokens - self.num_prompt_tokens
|
||
|
||
@property
|
||
def prompt_token_ids(self):
|
||
"""返回 prompt 部分的 token ID 列表。"""
|
||
return self.token_ids[:self.num_prompt_tokens]
|
||
|
||
@property
|
||
def completion_token_ids(self):
|
||
"""返回已生成的 completion token ID 列表。"""
|
||
return self.token_ids[self.num_prompt_tokens:]
|
||
|
||
@property
|
||
def num_blocks(self):
|
||
"""该序列需要的 KV cache 逻辑块数量(向上取整)。"""
|
||
return (self.num_tokens + self.block_size - 1) // self.block_size
|
||
|
||
@property
|
||
def last_block_num_tokens(self):
|
||
"""最后一个块中已使用的 token 数量。"""
|
||
return self.num_tokens - (self.num_blocks - 1) * self.block_size
|
||
|
||
def block(self, i):
|
||
"""获取第 i 个逻辑块对应的 token ID 列表。"""
|
||
assert 0 <= i < self.num_blocks
|
||
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
||
|
||
def append_token(self, token_id: int):
|
||
"""将一个新生成的 token 追加到序列末尾。"""
|
||
self.token_ids.append(token_id)
|
||
self.last_token = token_id
|
||
self.num_tokens += 1
|
||
|
||
def __getstate__(self):
|
||
"""序列化时只保存必要的状态,用于多进程间传递序列数据。
|
||
|
||
prefill 阶段保存完整 token_ids(模型需要读取全部 prompt tokens),
|
||
decode 阶段只保存 last_token(模型只需要最新一个 token 的 ID)。
|
||
"""
|
||
last_state = self.last_token if not self.is_prefill else self.token_ids
|
||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
|
||
|
||
def __setstate__(self, state):
|
||
"""反序列化,恢复序列状态。"""
|
||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state
|
||
if isinstance(last_state, list):
|
||
self.token_ids = last_state
|
||
self.last_token = self.token_ids[-1]
|
||
else:
|
||
# decode 阶段不需要完整的 token_ids,只保存了 last_token
|
||
self.token_ids = []
|
||
self.last_token = last_state
|