fix bug for tp

This commit is contained in:
Mengqi
2025-12-18 01:28:25 +08:00
parent 2f21442653
commit 82f5ca244f
2 changed files with 9 additions and 6 deletions
+1
View File
@@ -88,6 +88,7 @@ class BlockManager:
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.prefilled = False
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
+8 -6
View File
@@ -27,6 +27,7 @@ class Sequence:
self.temperature = sampling_params.temperature
self.max_tokens = sampling_params.max_tokens
self.ignore_eos = sampling_params.ignore_eos
self.prefilled = False
def __len__(self):
return self.num_tokens
@@ -70,14 +71,15 @@ class Sequence:
self.token_ids.append(token_id)
self.last_token = token_id
self.num_tokens += 1
self.prefilled = True
def __getstate__(self):
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
self.last_token if self.prefilled else self.token_ids)
def __setstate__(self, state):
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
if self.num_completion_tokens == 0:
self.token_ids = state[-1]
else:
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
if self.prefilled:
self.last_token = state[-1]
else:
self.token_ids = state[-1]