from copy import copy from enum import Enum, auto from itertools import count from nanovllm.sampling_params import SamplingParams class SequenceStatus(Enum): """序列的生命周期状态。 WAITING: 等待被调度器选中进入 prefill 阶段。 RUNNING: 正在解码(decode)阶段,逐 token 生成。 FINISHED: 生成完成(遇到 EOS 或达到 max_tokens)。 """ WAITING = auto() RUNNING = auto() FINISHED = auto() class Sequence: """表示一个推理请求的序列。 每个 Sequence 封装了一个完整的生成请求,从 prompt tokens 到生成的 completion tokens。 它是调度器、块管理器和模型运行器之间传递的核心数据结构。 关键概念: - prompt tokens: 用户输入的原始 token 序列。 - completion tokens: 模型生成的 token 序列。 - cached tokens: 已经计算过 KV cache 的 token 数量(用于前缀缓存和 chunked prefill)。 - scheduled tokens: 当前调度步骤中计划处理的 token 数量。 - block_table: 该序列在 KV cache 中占用的物理块 ID 列表。 类属性: block_size: KV cache 块大小,由 Config 在引擎初始化时设置。 counter: 全局自增计数器,用于为每个序列分配唯一 ID。 """ block_size = 256 counter = count() def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.seq_id = next(Sequence.counter) self.status = SequenceStatus.WAITING self.token_ids = copy(token_ids) self.last_token = token_ids[-1] self.num_tokens = len(self.token_ids) # 当前总 token 数(prompt + 已生成的) self.num_prompt_tokens = len(token_ids) # prompt 部分的 token 数,生成过程中不变 self.num_cached_tokens = 0 # 已计算 KV cache 的 token 数,用于前缀缓存命中判断 self.num_scheduled_tokens = 0 # 当前步骤中被调度处理的 token 数 self.is_prefill = True # 是否处于 prefill 阶段(首次计算 prompt 的 KV cache) self.block_table = [] # KV cache 物理块 ID 列表,索引为逻辑块号 self.temperature = sampling_params.temperature self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos def __len__(self): return self.num_tokens def __getitem__(self, key): return self.token_ids[key] @property def is_finished(self): return self.status == SequenceStatus.FINISHED @property def num_completion_tokens(self): """已生成的 completion token 数量。""" return self.num_tokens - self.num_prompt_tokens @property def prompt_token_ids(self): """返回 prompt 部分的 token ID 列表。""" return self.token_ids[:self.num_prompt_tokens] @property def completion_token_ids(self): """返回已生成的 completion token ID 列表。""" return self.token_ids[self.num_prompt_tokens:] @property def num_blocks(self): """该序列需要的 KV cache 逻辑块数量(向上取整)。""" return (self.num_tokens + self.block_size - 1) // self.block_size @property def last_block_num_tokens(self): """最后一个块中已使用的 token 数量。""" return self.num_tokens - (self.num_blocks - 1) * self.block_size def block(self, i): """获取第 i 个逻辑块对应的 token ID 列表。""" assert 0 <= i < self.num_blocks return self.token_ids[i*self.block_size: (i+1)*self.block_size] def append_token(self, token_id: int): """将一个新生成的 token 追加到序列末尾。""" self.token_ids.append(token_id) self.last_token = token_id self.num_tokens += 1 def __getstate__(self): """序列化时只保存必要的状态,用于多进程间传递序列数据。 prefill 阶段保存完整 token_ids(模型需要读取全部 prompt tokens), decode 阶段只保存 last_token(模型只需要最新一个 token 的 ID)。 """ last_state = self.last_token if not self.is_prefill else self.token_ids return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state) def __setstate__(self, state): """反序列化,恢复序列状态。""" self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state if isinstance(last_state, list): self.token_ids = last_state self.last_token = self.token_ids[-1] else: # decode 阶段不需要完整的 token_ids,只保存了 last_token self.token_ids = [] self.last_token = last_state