diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py index 65d725e..763aae2 100644 --- a/nanovllm/engine/block_manager.py +++ b/nanovllm/engine/block_manager.py @@ -88,6 +88,7 @@ class BlockManager: if block.ref_count == 0: self._deallocate_block(block_id) seq.num_cached_tokens = 0 + seq.prefilled = False seq.block_table.clear() def can_append(self, seq: Sequence) -> bool: diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 49d9ee6..e12f30d 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -27,6 +27,7 @@ class Sequence: self.temperature = sampling_params.temperature self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos + self.prefilled = False def __len__(self): return self.num_tokens @@ -70,14 +71,15 @@ class Sequence: self.token_ids.append(token_id) self.last_token = token_id self.num_tokens += 1 + self.prefilled = True def __getstate__(self): - return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, - self.token_ids if self.num_completion_tokens == 0 else self.last_token) + return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled, + self.last_token if self.prefilled else self.token_ids) def __setstate__(self, state): - self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1] - if self.num_completion_tokens == 0: - self.token_ids = state[-1] - else: + self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1] + if self.prefilled: self.last_token = state[-1] + else: + self.token_ids = state[-1]