From f16adb729e2602f5d015ab38fb4d5882ccdd6c20 Mon Sep 17 00:00:00 2001 From: GeeeekExplorer <2651904866@qq.com> Date: Thu, 12 Jun 2025 09:41:12 +0800 Subject: [PATCH] refactor --- nanovllm/config.py | 2 +- nanovllm/engine/block_manager.py | 3 --- nanovllm/engine/llm_engine.py | 21 +++++++++++---------- nanovllm/engine/scheduler.py | 10 ++-------- 4 files changed, 14 insertions(+), 22 deletions(-) diff --git a/nanovllm/config.py b/nanovllm/config.py index ae76ade..ce9446b 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -8,7 +8,7 @@ class Config: max_num_batched_tokens: int = 32768 max_num_seqs: int = 512 max_model_len: int = 4096 - gpu_memory_utilization: float = 0.95 + gpu_memory_utilization: float = 0.9 enforce_eager: bool = False hf_config: AutoConfig | None = None eos: int = -1 diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py index 72a0eef..fde8b37 100644 --- a/nanovllm/engine/block_manager.py +++ b/nanovllm/engine/block_manager.py @@ -57,9 +57,6 @@ class BlockManager: self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) - def can_prefill(self): - return len(self.free_block_ids) > 0.1 * len(self.blocks) - def can_allocate(self, seq: Sequence): return len(self.free_block_ids) >= seq.num_blocks diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 8464885..eaa16f8 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -1,11 +1,10 @@ -from collections import defaultdict from time import perf_counter from tqdm.auto import tqdm from transformers import AutoConfig, AutoTokenizer from nanovllm.config import Config from nanovllm.sampling_params import SamplingParams -from nanovllm.engine.sequence import Sequence +from nanovllm.engine.sequence import Sequence, SequenceStatus from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.model_runner import ModelRunner @@ -34,8 +33,10 @@ class LLMEngine: def step(self): seqs, is_prefill = self.scheduler.schedule() token_ids = self.model_runner.run(seqs, is_prefill) - finished = self.scheduler.postprocess(seqs, token_ids) - return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs) + self.scheduler.postprocess(seqs, token_ids) + outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED] + num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + return outputs, num_tokens def is_finished(self): return self.scheduler.is_finished() @@ -56,23 +57,23 @@ class LLMEngine: sampling_params = [sampling_params] * len(prompts) for prompt, sp in zip(prompts, sampling_params): self.add_request(prompt, sp) - outputs = defaultdict(list) + outputs = {} prefill_throughput = decode_throughput = 0. while not self.is_finished(): t = perf_counter() output, num_tokens = self.step() if use_tqdm: - if num_tokens > len(output): + if num_tokens > 0: prefill_throughput = num_tokens / (perf_counter() - t) else: - decode_throughput = num_tokens / (perf_counter() - t) + decode_throughput = -num_tokens / (perf_counter() - t) pbar.set_postfix({ "Prefill": f"{int(prefill_throughput)}tok/s", "Decode": f"{int(decode_throughput)}tok/s", }) - for seq_id, token_id, finish in output: - outputs[seq_id].append(token_id) - if finish and use_tqdm: + for seq_id, token_ids in output: + outputs[seq_id] = token_ids + if use_tqdm: pbar.update(1) outputs = [outputs[seq_id] for seq_id in sorted(outputs)] outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index cb8bfd1..6d1eb53 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -43,7 +43,6 @@ class Scheduler: return scheduled_seqs, True # decode - # self.running = deque(sorted(self.running)) while self.running and num_seqs < self.max_num_seqs: seq = self.running.popleft() while not self.block_manager.can_append(seq): @@ -59,8 +58,8 @@ class Scheduler: running = deque(scheduled_seqs) running.extend(self.running) self.running = running - if scheduled_seqs: - return scheduled_seqs, False + assert scheduled_seqs + return scheduled_seqs, False def preempt(self, seq: Sequence): seq.status = SequenceStatus.WAITING @@ -69,7 +68,6 @@ class Scheduler: def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: self.num_tokens += len(token_ids) - finished = [] for seq, token_id in zip(seqs, token_ids): seq.append_token(token_id) if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: @@ -77,7 +75,3 @@ class Scheduler: self.block_manager.deallocate(seq) self.running.remove(seq) self.num_finished += 1 - finished.append(True) - else: - finished.append(False) - return finished