Files
nano-vllm/nanovllm/engine/scheduler.py
T
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

154 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from collections import deque
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
class Scheduler:
"""调度器:决定每个步骤(step)中哪些序列被处理以及处理多少 token。
调度策略采用 vLLM 风格的 prefill-decode 分离调度:
1. **Prefill 优先**: 每次调度先尝试处理等待中的序列(计算其 prompt 的 KV cache)。
2. **Chunked prefill**: 如果一个序列的 prompt 太长,可以分多次调度处理。
3. **Decode**: 当没有等待中的序列时(或 prefill token 额度用完),处理正在解码的序列。
4. **抢占(Preemption**: 当 KV cache 空间不足时,将最近的 running 序列抢占回 waiting 队列。
调度约束:
- 总 token 数不超过 max_num_batched_tokensprefill 阶段)。
- 总序列数不超过 max_num_seqs。
Attributes:
waiting: 等待处理的序列队列(FIFO)。
running: 正在解码的序列队列。
"""
def __init__(self, config: Config):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
self.block_size = config.kvcache_block_size
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
def is_finished(self):
"""检查是否所有序列都已完成。"""
return not self.waiting and not self.running
def add(self, seq: Sequence):
"""将一个新序列加入等待队列。"""
self.waiting.append(seq)
def schedule(self) -> tuple[list[Sequence], bool]:
"""执行一次调度,返回 (被调度的序列列表, 是否为 prefill 阶段)。
调度逻辑:
1. Prefill 阶段:从 waiting 队列中选取序列,检查前缀缓存命中情况,
为每个序列计算需要处理的 token 数量。支持 chunked prefill(长 prompt 分多次处理)。
2. Decode 阶段:从 running 队列中选取序列,每个序列处理 1 个 token。
如果 KV cache 空间不足,会抢占(preempt)最近加入 running 的序列。
"""
scheduled_seqs = []
num_batched_tokens = 0
# ========== Prefill 阶段 ==========
# 尝试从 waiting 队列中调度序列,计算它们的 prompt KV cache
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
seq = self.waiting[0]
remaining = self.max_num_batched_tokens - num_batched_tokens
if remaining == 0:
break
if not seq.block_table:
# 序列尚未分配块,检查前缀缓存和空闲块
num_cached_blocks = self.block_manager.can_allocate(seq)
if num_cached_blocks == -1:
# 空闲块不足,停止调度
break
# 需要实际处理的 token 数 = 总 prompt token 数 - 缓存命中的 token 数
num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
else:
# 序列已经有块(chunked prefill 的后续分片),只需处理未缓存的 token
num_tokens = seq.num_tokens - seq.num_cached_tokens
if remaining < num_tokens and scheduled_seqs:
# token 预算不足以处理整个序列,且已有其他序列被调度
# 注意:第一个序列允许 chunked prefillremaining < num_tokens 也可以)
break
if not seq.block_table:
self.block_manager.allocate(seq, num_cached_blocks)
# 实际调度的 token 数取 min(num_tokens, remaining),实现 chunked prefill
seq.num_scheduled_tokens = min(num_tokens, remaining)
num_batched_tokens += seq.num_scheduled_tokens
if seq.num_cached_tokens + seq.num_scheduled_tokens == seq.num_tokens:
# 整个 prompt 已全部处理完毕,转移到 running 队列
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)
if scheduled_seqs:
return scheduled_seqs, True # is_prefill = True
# ========== Decode 阶段 ==========
# 逐 token 解码,每个序列每次生成 1 个 token
while self.running and len(scheduled_seqs) < self.max_num_seqs:
seq = self.running.popleft()
# 检查是否有空闲块用于存储新的 KV cache
while not self.block_manager.can_append(seq):
if self.running:
# 空间不足,抢占最近加入 running 的序列
self.preempt(self.running.pop())
else:
# 连当前序列都要被抢占
self.preempt(seq)
break
else:
seq.num_scheduled_tokens = 1
seq.is_prefill = False
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs, "No sequences to schedule"
# 将调度过的序列放回 running 队列前端(保持顺序)
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False # is_prefill = False
def preempt(self, seq: Sequence):
"""抢占一个序列:释放其 KV cache 并放回等待队列头部。
抢占后序列需要重新做 prefill(重新计算 KV cache)。
这是一种牺牲吞吐量来换取 KV cache 空间的策略。
"""
seq.status = SequenceStatus.WAITING
seq.is_prefill = True
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
"""在模型推理完成后处理每个序列的结果。
1. 更新块的哈希值(用于前缀缓存)。
2. 更新已缓存的 token 计数。
3. 对于 prefill:如果整个 prompt 还没处理完,继续等待下一次调度。
4. 对于完成的序列(prefill 结束后或 decode 中):追加生成的 token。
5. 检查终止条件(EOS 或达到 max_tokens),完成的序列被释放资源。
"""
for seq, token_id in zip(seqs, token_ids):
self.block_manager.hash_blocks(seq)
seq.num_cached_tokens += seq.num_scheduled_tokens
seq.num_scheduled_tokens = 0
# prefill 阶段如果还没处理完整个 prompt,不采样 token,继续等待
if is_prefill and seq.num_cached_tokens < seq.num_tokens:
continue
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
self.running.remove(seq)