Merge pull request #218 from GeeeekExplorer/chunked-prefill-refactor

fix chunked prefill bugs and refactor
This commit is contained in:
Xingkai Yu
2026-04-26 13:10:12 +08:00
committed by GitHub
4 changed files with 71 additions and 67 deletions
+48 -40
View File
@@ -40,46 +40,56 @@ class BlockManager:
h.update(np.array(token_ids).tobytes()) h.update(np.array(token_ids).tobytes())
return h.intdigest() return h.intdigest()
def _allocate_block(self, block_id: int) -> Block: def _allocate_block(self) -> int:
block_id = self.free_block_ids.popleft()
block = self.blocks[block_id] block = self.blocks[block_id]
assert block.ref_count == 0 assert block.ref_count == 0
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
block.reset() block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id) self.used_block_ids.add(block_id)
return block return block_id
def _deallocate_block(self, block_id: int) -> Block: def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_count == 0 assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id) self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id) self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence) -> bool: def can_allocate(self, seq: Sequence) -> int:
return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence):
assert not seq.block_table
h = -1 h = -1
cache_miss = False num_cached_blocks = 0
for i in range(seq.num_blocks): num_new_blocks = seq.num_blocks
for i in range(seq.num_blocks - 1):
token_ids = seq.block(i) token_ids = seq.block(i)
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1 h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id.get(h, -1) block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids: if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True break
if cache_miss: num_cached_blocks += 1
block_id = self.free_block_ids[0] if block_id in self.used_block_ids:
block = self._allocate_block(block_id) num_new_blocks -= 1
if len(self.free_block_ids) < num_new_blocks:
return -1
return num_cached_blocks
def allocate(self, seq: Sequence, num_cached_blocks: int):
assert not seq.block_table
h = -1
for i in range(num_cached_blocks):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h)
block_id = self.hash_to_block_id[h]
block = self.blocks[block_id]
if block_id in self.used_block_ids:
block.ref_count += 1
else: else:
seq.num_cached_tokens += self.block_size block.ref_count = 1
if block_id in self.used_block_ids: self.free_block_ids.remove(block_id)
block = self.blocks[block_id] self.used_block_ids.add(block_id)
block.ref_count += 1
else:
block = self._allocate_block(block_id)
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id) seq.block_table.append(block_id)
for i in range(num_cached_blocks, seq.num_blocks):
seq.block_table.append(self._allocate_block())
seq.num_cached_tokens = num_cached_blocks * self.block_size
def deallocate(self, seq: Sequence): def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table): for block_id in reversed(seq.block_table):
@@ -94,19 +104,17 @@ class BlockManager:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence): def may_append(self, seq: Sequence):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
if len(seq) % self.block_size == 1: if len(seq) % self.block_size == 1:
assert last_block.hash != -1 seq.block_table.append(self._allocate_block())
block_id = self.free_block_ids[0]
self._allocate_block(block_id) def hash_blocks(self, seq: Sequence):
block_table.append(block_id) start = seq.num_cached_tokens // self.block_size
elif len(seq) % self.block_size == 0: end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
assert last_block.hash == -1 if start == end: return
token_ids = seq.block(seq.num_blocks-1) h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 for i in range(start, end):
h = self.compute_hash(token_ids, prefix) block = self.blocks[seq.block_table[i]]
last_block.update(h, token_ids) token_ids = seq.block(i)
self.hash_to_block_id[h] = last_block.block_id h = self.compute_hash(token_ids, h)
else: block.update(h, token_ids)
assert last_block.hash == -1 self.hash_to_block_id[h] = block.block_id
+1 -2
View File
@@ -136,8 +136,7 @@ class ModelRunner:
slot_mapping = [] slot_mapping = []
block_tables = None block_tables = None
for seq in seqs: for seq in seqs:
seqlen = len(seq) start = seq.num_cached_tokens
start = min(seq.num_cached_tokens, seqlen - 1)
seqlen_q = seq.num_scheduled_tokens seqlen_q = seq.num_scheduled_tokens
end = start + seqlen_q end = start + seqlen_q
seqlen_k = end seqlen_k = end
+19 -23
View File
@@ -11,6 +11,7 @@ class Scheduler:
self.max_num_seqs = config.max_num_seqs self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos self.eos = config.eos
self.block_size = config.kvcache_block_size
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size) self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque() self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque() self.running: deque[Sequence] = deque()
@@ -29,31 +30,26 @@ class Scheduler:
while self.waiting and len(scheduled_seqs) < self.max_num_seqs: while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
seq = self.waiting[0] seq = self.waiting[0]
remaining = self.max_num_batched_tokens - num_batched_tokens remaining = self.max_num_batched_tokens - num_batched_tokens
if remaining == 0 or (not seq.block_table and not self.block_manager.can_allocate(seq)): if remaining == 0:
break break
if not seq.block_table: if not seq.block_table:
self.block_manager.allocate(seq) num_cached_blocks = self.block_manager.can_allocate(seq)
if num_cached_blocks == -1:
# Re-calculate num_tokens after allocate(), as prefix caching may update break
# seq.num_cached_tokens during the allocation process. num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
# else:
# Using an outdated num_cached_tokens would overestimate num_scheduled_tokens, num_tokens = seq.num_tokens - seq.num_cached_tokens
# leading to an inflated 'end' and 'end_block' in prepare_prefill (model_runner.py).
# This results in an 'index out of range' at line 155 when accessing
# seq.block_table[i] beyond its actual physical allocation.
num_tokens = max(seq.num_tokens - seq.num_cached_tokens, 1)
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
break break
if not seq.block_table:
self.block_manager.allocate(seq, num_cached_blocks)
seq.num_scheduled_tokens = min(num_tokens, remaining) seq.num_scheduled_tokens = min(num_tokens, remaining)
if seq.num_scheduled_tokens == num_tokens: num_batched_tokens += seq.num_scheduled_tokens
if seq.num_cached_tokens + seq.num_scheduled_tokens == seq.num_tokens:
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
self.waiting.popleft() self.waiting.popleft()
self.running.append(seq) self.running.append(seq)
scheduled_seqs.append(seq) scheduled_seqs.append(seq)
num_batched_tokens += seq.num_scheduled_tokens
if scheduled_seqs: if scheduled_seqs:
return scheduled_seqs, True return scheduled_seqs, True
@@ -69,6 +65,7 @@ class Scheduler:
break break
else: else:
seq.num_scheduled_tokens = 1 seq.num_scheduled_tokens = 1
seq.is_prefill = False
self.block_manager.may_append(seq) self.block_manager.may_append(seq)
scheduled_seqs.append(seq) scheduled_seqs.append(seq)
assert scheduled_seqs assert scheduled_seqs
@@ -77,19 +74,18 @@ class Scheduler:
def preempt(self, seq: Sequence): def preempt(self, seq: Sequence):
seq.status = SequenceStatus.WAITING seq.status = SequenceStatus.WAITING
seq.is_prefill = True
self.block_manager.deallocate(seq) self.block_manager.deallocate(seq)
self.waiting.appendleft(seq) self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool): def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
for seq, token_id in zip(seqs, token_ids): for seq, token_id in zip(seqs, token_ids):
if is_prefill: self.block_manager.hash_blocks(seq)
seq.num_cached_tokens = min(seq.num_cached_tokens + seq.num_scheduled_tokens, seq.num_tokens) seq.num_cached_tokens += seq.num_scheduled_tokens
if seq.num_cached_tokens < seq.num_tokens or seq.num_completion_tokens > 0: # chunked prefill or re prefill after preemption
seq.num_scheduled_tokens = 0
continue
seq.append_token(token_id)
seq.num_cached_tokens += 1
seq.num_scheduled_tokens = 0 seq.num_scheduled_tokens = 0
if is_prefill and seq.num_cached_tokens < seq.num_tokens:
continue
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq) self.block_manager.deallocate(seq)
+3 -2
View File
@@ -22,8 +22,9 @@ class Sequence:
self.last_token = token_ids[-1] self.last_token = token_ids[-1]
self.num_tokens = len(self.token_ids) self.num_tokens = len(self.token_ids)
self.num_prompt_tokens = len(token_ids) self.num_prompt_tokens = len(token_ids)
self.num_cached_tokens = 0 # tokens that don't need prefill self.num_cached_tokens = 0
self.num_scheduled_tokens = 0 self.num_scheduled_tokens = 0
self.is_prefill = True
self.block_table = [] self.block_table = []
self.temperature = sampling_params.temperature self.temperature = sampling_params.temperature
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
@@ -69,7 +70,7 @@ class Sequence:
self.num_tokens += 1 self.num_tokens += 1
def __getstate__(self): def __getstate__(self):
last_state = self.token_ids if self.num_completion_tokens == 0 or self.num_cached_tokens < self.num_tokens else self.last_token last_state = self.last_token if not self.is_prefill else self.token_ids
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state) return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
def __setstate__(self, state): def __setstate__(self, state):