minor simplify
This commit is contained in:
@@ -63,8 +63,7 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
|
||||
if not isinstance(sampling_params, list):
|
||||
sampling_params = [sampling_params] * len(prompts)
|
||||
for prompt, sp in zip(prompts, sampling_params):
|
||||
@@ -74,21 +73,18 @@ class LLMEngine:
|
||||
while not self.is_finished():
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
if use_tqdm:
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||
pbar.set_postfix({
|
||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||
"Decode": f"{int(decode_throughput)}tok/s",
|
||||
})
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||
pbar.set_postfix({
|
||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||
"Decode": f"{int(decode_throughput)}tok/s",
|
||||
})
|
||||
for seq_id, token_ids in output:
|
||||
outputs[seq_id] = token_ids
|
||||
if use_tqdm:
|
||||
pbar.update(1)
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user