minor simplify
This commit is contained in:
@@ -46,7 +46,7 @@ class BlockManager:
|
|||||||
block.reset()
|
block.reset()
|
||||||
self.free_block_ids.remove(block_id)
|
self.free_block_ids.remove(block_id)
|
||||||
self.used_block_ids.add(block_id)
|
self.used_block_ids.add(block_id)
|
||||||
return self.blocks[block_id]
|
return block
|
||||||
|
|
||||||
def _deallocate_block(self, block_id: int) -> Block:
|
def _deallocate_block(self, block_id: int) -> Block:
|
||||||
assert self.blocks[block_id].ref_count == 0
|
assert self.blocks[block_id].ref_count == 0
|
||||||
|
|||||||
@@ -63,8 +63,7 @@ class LLMEngine:
|
|||||||
sampling_params: SamplingParams | list[SamplingParams],
|
sampling_params: SamplingParams | list[SamplingParams],
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
if use_tqdm:
|
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
|
||||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
|
||||||
if not isinstance(sampling_params, list):
|
if not isinstance(sampling_params, list):
|
||||||
sampling_params = [sampling_params] * len(prompts)
|
sampling_params = [sampling_params] * len(prompts)
|
||||||
for prompt, sp in zip(prompts, sampling_params):
|
for prompt, sp in zip(prompts, sampling_params):
|
||||||
@@ -74,21 +73,18 @@ class LLMEngine:
|
|||||||
while not self.is_finished():
|
while not self.is_finished():
|
||||||
t = perf_counter()
|
t = perf_counter()
|
||||||
output, num_tokens = self.step()
|
output, num_tokens = self.step()
|
||||||
if use_tqdm:
|
if num_tokens > 0:
|
||||||
if num_tokens > 0:
|
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
else:
|
||||||
else:
|
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||||
decode_throughput = -num_tokens / (perf_counter() - t)
|
pbar.set_postfix({
|
||||||
pbar.set_postfix({
|
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
"Decode": f"{int(decode_throughput)}tok/s",
|
||||||
"Decode": f"{int(decode_throughput)}tok/s",
|
})
|
||||||
})
|
|
||||||
for seq_id, token_ids in output:
|
for seq_id, token_ids in output:
|
||||||
outputs[seq_id] = token_ids
|
outputs[seq_id] = token_ids
|
||||||
if use_tqdm:
|
pbar.update(1)
|
||||||
pbar.update(1)
|
pbar.close()
|
||||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||||
if use_tqdm:
|
|
||||||
pbar.close()
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -180,9 +180,7 @@ class ModelRunner:
|
|||||||
return input_ids, positions
|
return input_ids, positions
|
||||||
|
|
||||||
def prepare_sample(self, seqs: list[Sequence]):
|
def prepare_sample(self, seqs: list[Sequence]):
|
||||||
temperatures = []
|
temperatures = [seq.temperature for seq in seqs]
|
||||||
for seq in seqs:
|
|
||||||
temperatures.append(seq.temperature)
|
|
||||||
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
||||||
return temperatures
|
return temperatures
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
class SiluAndMul(nn.Module):
|
class SiluAndMul(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x, y = x.chunk(2, -1)
|
x, y = x.chunk(2, -1)
|
||||||
|
|||||||
@@ -141,8 +141,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
if param_data.dim() == 1:
|
if param_data.ndim == 1:
|
||||||
# bias is not sharded in RowParallelLinear
|
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
return
|
return
|
||||||
shard_size = param_data.size(self.tp_dim)
|
shard_size = param_data.size(self.tp_dim)
|
||||||
|
|||||||
@@ -4,9 +4,6 @@ from torch import nn
|
|||||||
|
|
||||||
class Sampler(nn.Module):
|
class Sampler(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
||||||
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
||||||
|
|||||||
Reference in New Issue
Block a user