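Fix bug for TP (tensor parallelism): track prefill state with an explicit `prefilled` flag, so a sequence whose blocks were deallocated is serialized with its full `token_ids` again instead of only `last_token`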
This commit is contained in:
@@ -88,6 +88,7 @@ class BlockManager:
|
|||||||
if block.ref_count == 0:
|
if block.ref_count == 0:
|
||||||
self._deallocate_block(block_id)
|
self._deallocate_block(block_id)
|
||||||
seq.num_cached_tokens = 0
|
seq.num_cached_tokens = 0
|
||||||
|
seq.prefilled = False
|
||||||
seq.block_table.clear()
|
seq.block_table.clear()
|
||||||
|
|
||||||
def can_append(self, seq: Sequence) -> bool:
|
def can_append(self, seq: Sequence) -> bool:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class Sequence:
|
|||||||
self.temperature = sampling_params.temperature
|
self.temperature = sampling_params.temperature
|
||||||
self.max_tokens = sampling_params.max_tokens
|
self.max_tokens = sampling_params.max_tokens
|
||||||
self.ignore_eos = sampling_params.ignore_eos
|
self.ignore_eos = sampling_params.ignore_eos
|
||||||
|
self.prefilled = False
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_tokens
|
return self.num_tokens
|
||||||
@@ -70,14 +71,15 @@ class Sequence:
|
|||||||
self.token_ids.append(token_id)
|
self.token_ids.append(token_id)
|
||||||
self.last_token = token_id
|
self.last_token = token_id
|
||||||
self.num_tokens += 1
|
self.num_tokens += 1
|
||||||
|
self.prefilled = True
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
|
||||||
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
self.last_token if self.prefilled else self.token_ids)
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
|
||||||
if self.num_completion_tokens == 0:
|
if self.prefilled:
|
||||||
self.token_ids = state[-1]
|
|
||||||
else:
|
|
||||||
self.last_token = state[-1]
|
self.last_token = state[-1]
|
||||||
|
else:
|
||||||
|
self.token_ids = state[-1]
|
||||||
|
|||||||
Reference in New Issue
Block a user