Files
nano-vllm/nanovllm/engine/scheduler.py
T
six 77dd709ca1 fix(scheduler): recalculate num_tokens after allocate to prevent IndexError
The scheduler overestimated num_scheduled_tokens because it used an outdated num_cached_tokens before block_manager.allocate(seq) could update it via prefix cache hits. In prepare_prefill (model_runner.py), this caused 'end = start + seqlen_q' to exceed the sequence length, leading to an inflated 'end_block'. Consequently, an 'index out of range' error occurred at line 155 when accessing seq.block_table[i] beyond its actual physical allocation.
2026-04-20 16:34:27 +08:00

97 lines
4.0 KiB
Python

from collections import deque
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
class Scheduler:
    """Choose which sequences run next and track their lifecycle.

    Sequences enter through :meth:`add` into the ``waiting`` queue, move to
    ``running`` once the rest of their tokens has been scheduled for prefill,
    and leave ``running`` when :meth:`postprocess` marks them finished.
    Cache blocks are obtained from and returned to ``block_manager``.
    """

    def __init__(self, config: Config):
        """Copy the batching limits and EOS token id from *config*.

        Args:
            config: Supplies ``max_num_seqs``, ``max_num_batched_tokens``,
                ``eos``, ``num_kvcache_blocks`` and ``kvcache_block_size``.
        """
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(
            config.num_kvcache_blocks,
            config.kvcache_block_size,
        )
        # Sequences that still need (more) prefill, in scheduling order.
        self.waiting: deque[Sequence] = deque()
        # Sequences whose prefill has been fully scheduled.
        self.running: deque[Sequence] = deque()

    def is_finished(self):
        """Return True when neither queue holds any sequence."""
        return not (self.waiting or self.running)

    def add(self, seq: Sequence):
        """Enqueue *seq* at the back of the waiting queue."""
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        """Build the batch for the next step.

        Prefill has priority: when at least one waiting sequence can be
        scheduled, only prefill work is returned. Otherwise a single decode
        token is scheduled for every running sequence that can get cache space.

        Returns:
            ``(sequences, is_prefill)`` where ``is_prefill`` tells which of the
            two phases produced ``sequences``.

        Raises:
            AssertionError: If neither phase could schedule a sequence.
        """
        prefill_batch = self._schedule_prefill()
        if prefill_batch:
            return prefill_batch, True

        decode_batch = self._schedule_decode()
        assert decode_batch
        # The decode pass popped these from the front of ``running``; put them
        # back there in their original order.
        self.running.extendleft(reversed(decode_batch))
        return decode_batch, False

    def _schedule_prefill(self) -> list[Sequence]:
        """Take sequences from the front of ``waiting`` within the token budget."""
        batch = []
        budget = self.max_num_batched_tokens
        while self.waiting and len(batch) < self.max_num_seqs:
            head = self.waiting[0]
            if budget == 0:
                break
            if not head.block_table:
                if not self.block_manager.can_allocate(head):
                    break
                self.block_manager.allocate(head)
            # allocate() may raise head.num_cached_tokens through prefix-cache
            # hits, so the outstanding token count must be read only after it
            # has run. Sizing the chunk from the stale value over-counts and
            # makes prepare_prefill (model_runner.py) index past the end of
            # seq.block_table.
            pending = max(head.num_tokens - head.num_cached_tokens, 1)
            if pending > budget and batch:
                # Only the first sequence of a batch may be prefilled in chunks.
                break
            chunk = min(pending, budget)
            head.num_scheduled_tokens = chunk
            if chunk == pending:
                # Everything outstanding fits: the sequence is done waiting.
                head.status = SequenceStatus.RUNNING
                self.running.append(self.waiting.popleft())
            batch.append(head)
            budget -= chunk
        return batch

    def _schedule_decode(self) -> list[Sequence]:
        """Schedule one token for each running sequence that can be extended."""
        batch = []
        while self.running and len(batch) < self.max_num_seqs:
            candidate = self.running.popleft()
            if not self._make_room_for_append(candidate):
                continue
            candidate.num_scheduled_tokens = 1
            self.block_manager.may_append(candidate)
            batch.append(candidate)
        return batch

    def _make_room_for_append(self, seq: Sequence) -> bool:
        """Preempt sequences until *seq* can be appended to; False if *seq* itself was preempted."""
        while not self.block_manager.can_append(seq):
            if not self.running:
                # Nobody else left to evict, so give up on seq itself.
                self.preempt(seq)
                return False
            # Evict from the back of the running queue first.
            self.preempt(self.running.pop())
        return True

    def preempt(self, seq: Sequence):
        """Release the blocks of *seq* and put it at the front of ``waiting``."""
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        self.waiting.appendleft(seq)

    def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
        """Apply the outcome of one step to the scheduled sequences.

        Args:
            seqs: The sequences that were returned by :meth:`schedule`.
            token_ids: One new token id per sequence, paired with *seqs* by
                position.
            is_prefill: Whether the step was a prefill step.
        """
        for seq, token_id in zip(seqs, token_ids):
            if is_prefill:
                seq.num_cached_tokens = min(
                    seq.num_cached_tokens + seq.num_scheduled_tokens,
                    seq.num_tokens,
                )
                prefill_incomplete = seq.num_cached_tokens < seq.num_tokens
                # Chunked prefill, or re-prefill after preemption: the token
                # from this step is not appended.
                if prefill_incomplete or seq.num_completion_tokens > 0:
                    seq.num_scheduled_tokens = 0
                    continue

            seq.append_token(token_id)
            seq.num_cached_tokens += 1
            seq.num_scheduled_tokens = 0

            hit_eos = not seq.ignore_eos and token_id == self.eos
            if hit_eos or seq.num_completion_tokens == seq.max_tokens:
                seq.status = SequenceStatus.FINISHED
                self.block_manager.deallocate(seq)
                self.running.remove(seq)