Merge pull request #145 from LiaoMengqi/fix/tp
bug for tensor parallelism # issue 144
This commit is contained in:
@@ -88,6 +88,7 @@ class BlockManager:
|
||||
if block.ref_count == 0:
|
||||
self._deallocate_block(block_id)
|
||||
seq.num_cached_tokens = 0
|
||||
seq.prefilled = False
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
|
||||
@@ -27,6 +27,7 @@ class Sequence:
|
||||
self.temperature = sampling_params.temperature
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
self.ignore_eos = sampling_params.ignore_eos
|
||||
self.prefilled = False
|
||||
|
||||
def __len__(self):
|
||||
return self.num_tokens
|
||||
@@ -70,14 +71,15 @@ class Sequence:
|
||||
self.token_ids.append(token_id)
|
||||
self.last_token = token_id
|
||||
self.num_tokens += 1
|
||||
self.prefilled = True
|
||||
|
||||
def __getstate__(self):
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
||||
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
|
||||
self.last_token if self.prefilled else self.token_ids)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
||||
if self.num_completion_tokens == 0:
|
||||
self.token_ids = state[-1]
|
||||
else:
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
|
||||
if self.prefilled:
|
||||
self.last_token = state[-1]
|
||||
else:
|
||||
self.token_ids = state[-1]
|
||||
|
||||
Reference in New Issue
Block a user