diff --git a/nanovllm/config.py b/nanovllm/config.py index b999838..7066cbe 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -23,4 +23,3 @@ class Config: assert 1 <= self.tensor_parallel_size <= 8 self.hf_config = AutoConfig.from_pretrained(self.model) self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) - assert self.max_num_batched_tokens >= self.max_model_len diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py index 1ad00f6..f835a29 100644 --- a/nanovllm/engine/block_manager.py +++ b/nanovllm/engine/block_manager.py @@ -88,7 +88,6 @@ class BlockManager: if block.ref_count == 0: self._deallocate_block(block_id) seq.num_cached_tokens = 0 - seq.prefilled = False seq.block_table.clear() def can_append(self, seq: Sequence) -> bool: diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 45e75aa..3685094 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -18,7 +18,7 @@ class LLMEngine: config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) - Sequence.set_block_size(config.kvcache_block_size) + Sequence.block_size = config.kvcache_block_size self.ps = [] self.events = [] ctx = mp.get_context("spawn") @@ -48,10 +48,10 @@ class LLMEngine: def step(self): seqs, is_prefill = self.scheduler.schedule() + num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs) token_ids = self.model_runner.call("run", seqs, is_prefill) - self.scheduler.postprocess(seqs, token_ids) + self.scheduler.postprocess(seqs, token_ids, is_prefill) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] - num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) return outputs, num_tokens def is_finished(self): diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index e70a193..5e6342b 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -26,7 +26,7 @@ class ModelRunner: dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank) torch.cuda.set_device(rank) default_dtype = torch.get_default_dtype() - torch.set_default_dtype(hf_config.torch_dtype) + torch.set_default_dtype(hf_config.dtype) torch.set_default_device("cuda") self.model = Qwen3ForCausalLM(hf_config) load_model(self.model, config.model) @@ -92,8 +92,11 @@ class ModelRunner: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len - num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) - seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] + seq_len = min(max_num_batched_tokens, max_model_len) + num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs) + seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)] + for seq in seqs: + seq.num_scheduled_tokens = seq_len self.run(seqs, True) torch.cuda.empty_cache() @@ -106,7 +109,7 @@ class ModelRunner: current = torch.cuda.memory_stats()["allocated_bytes.all.current"] num_kv_heads = hf_config.num_key_value_heads // self.world_size head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads) - block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize + block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes assert config.num_kvcache_blocks > 0 self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim) @@ -134,23 +137,29 @@ class ModelRunner: block_tables = None for seq in seqs: seqlen = len(seq) - input_ids.extend(seq[seq.num_cached_tokens:]) - positions.extend(list(range(seq.num_cached_tokens, seqlen))) - seqlen_q = seqlen - seq.num_cached_tokens + start = min(seq.num_cached_tokens, seqlen - 1) + seqlen_q = seq.num_scheduled_tokens seqlen_k = seqlen + end = start + seqlen_q + input_ids.extend(seq[start:end]) + positions.extend(range(start, end)) cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) if not seq.block_table: # warmup continue - for i in range(seq.num_cached_blocks, seq.num_blocks): - start = seq.block_table[i] * self.block_size - if i != seq.num_blocks - 1: - end = start + self.block_size + start_block = start // self.block_size + end_block = (end + self.block_size - 1) // self.block_size + for i in range(start_block, end_block): + slot_start = seq.block_table[i] * self.block_size + if i == start_block: + slot_start += start % self.block_size + if i != end_block - 1: + slot_end = seq.block_table[i] * self.block_size + self.block_size else: - end = start + seq.last_block_num_tokens - slot_mapping.extend(list(range(start, end))) + slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size + slot_mapping.extend(range(slot_start, slot_end)) if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache block_tables = self.prepare_block_tables(seqs) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index f2eabdf..287dd62 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -22,26 +22,32 @@ class Scheduler: self.waiting.append(seq) def schedule(self) -> tuple[list[Sequence], bool]: - # prefill scheduled_seqs = [] - num_seqs = 0 num_batched_tokens = 0 - while self.waiting and num_seqs < self.max_num_seqs: + + # prefill + while self.waiting and len(scheduled_seqs) < self.max_num_seqs: seq = self.waiting[0] - if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): + num_tokens = max(seq.num_tokens - seq.num_cached_tokens, 1) + remaining = self.max_num_batched_tokens - num_batched_tokens + if remaining == 0 or (not seq.block_table and not self.block_manager.can_allocate(seq)): # no budget break - num_seqs += 1 - self.block_manager.allocate(seq) - num_batched_tokens += len(seq) - seq.num_cached_tokens - seq.status = SequenceStatus.RUNNING - self.waiting.popleft() - self.running.append(seq) + if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq + break + if not seq.block_table: + self.block_manager.allocate(seq) + seq.num_scheduled_tokens = min(num_tokens, remaining) + if seq.num_scheduled_tokens == num_tokens: + seq.status = SequenceStatus.RUNNING + self.waiting.popleft() + self.running.append(seq) scheduled_seqs.append(seq) + num_batched_tokens += seq.num_scheduled_tokens if scheduled_seqs: return scheduled_seqs, True # decode - while self.running and num_seqs < self.max_num_seqs: + while self.running and len(scheduled_seqs) < self.max_num_seqs: seq = self.running.popleft() while not self.block_manager.can_append(seq): if self.running: @@ -50,7 +56,7 @@ class Scheduler: self.preempt(seq) break else: - num_seqs += 1 + seq.num_scheduled_tokens = 1 self.block_manager.may_append(seq) scheduled_seqs.append(seq) assert scheduled_seqs @@ -62,9 +68,16 @@ class Scheduler: self.block_manager.deallocate(seq) self.waiting.appendleft(seq) - def postprocess(self, seqs: list[Sequence], token_ids: list[int]): + def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool): for seq, token_id in zip(seqs, token_ids): + if is_prefill: + seq.num_cached_tokens = min(seq.num_cached_tokens + seq.num_scheduled_tokens, seq.num_tokens) + if seq.num_cached_tokens < seq.num_tokens or seq.num_completion_tokens > 0: # chunked prefill or re prefill after preemption + seq.num_scheduled_tokens = 0 + continue seq.append_token(token_id) + seq.num_cached_tokens += 1 + seq.num_scheduled_tokens = 0 if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: seq.status = SequenceStatus.FINISHED self.block_manager.deallocate(seq) diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 3168888..d90d149 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -12,13 +12,9 @@ class SequenceStatus(Enum): class Sequence: - block_size: int = 0 # invalid value, will be set by set_block_size + block_size = 256 counter = count() - @classmethod - def set_block_size(cls, block_size: int): - cls.block_size = block_size - def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.seq_id = next(Sequence.counter) self.status = SequenceStatus.WAITING @@ -26,12 +22,12 @@ class Sequence: self.last_token = token_ids[-1] self.num_tokens = len(self.token_ids) self.num_prompt_tokens = len(token_ids) - self.num_cached_tokens = 0 + self.num_cached_tokens = 0 # tokens that don't need prefill + self.num_scheduled_tokens = 0 self.block_table = [] self.temperature = sampling_params.temperature self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos - self.prefilled = False def __len__(self): return self.num_tokens @@ -55,10 +51,6 @@ class Sequence: def completion_token_ids(self): return self.token_ids[self.num_prompt_tokens:] - @property - def num_cached_blocks(self): - return self.num_cached_tokens // self.block_size - @property def num_blocks(self): return (self.num_tokens + self.block_size - 1) // self.block_size @@ -75,15 +67,16 @@ class Sequence: self.token_ids.append(token_id) self.last_token = token_id self.num_tokens += 1 - self.prefilled = True def __getstate__(self): - return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled, - self.last_token if self.prefilled else self.token_ids) + last_state = self.token_ids if self.num_completion_tokens == 0 or self.num_cached_tokens < self.num_tokens else self.last_token + return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state) def __setstate__(self, state): - self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1] - if self.prefilled: - self.last_token = state[-1] + self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state + if isinstance(last_state, list): + self.token_ids = last_state + self.last_token = self.token_ids[-1] else: - self.token_ids = state[-1] + self.token_ids = [] + self.last_token = last_state diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py index 998d116..f4747a4 100644 --- a/nanovllm/layers/rotary_embedding.py +++ b/nanovllm/layers/rotary_embedding.py @@ -54,8 +54,6 @@ def get_rope( rotary_dim: int, max_position: int, base: float, - rope_scaling: dict | None = None, ): - assert rope_scaling is None rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) return rotary_emb diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py index 5d39e0b..cb147f6 100755 --- a/nanovllm/models/qwen3.py +++ b/nanovllm/models/qwen3.py @@ -23,7 +23,7 @@ class Qwen3Attention(nn.Module): rms_norm_eps: float = 1e-06, qkv_bias: bool = False, rope_theta: float = 10000, - rope_scaling: tuple | None = None, + rope_scaling: dict | None = None, ) -> None: super().__init__() tp_size = dist.get_world_size() @@ -51,12 +51,13 @@ class Qwen3Attention(nn.Module): hidden_size, bias=False, ) + if isinstance(rope_scaling, dict): + rope_theta = rope_scaling.get("rope_theta", rope_theta) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position, base=rope_theta, - rope_scaling=rope_scaling, ) self.attn = Attention( self.num_heads,