Merge pull request #204 from GeeeekExplorer/chunked-prefill

Chunked Prefill
This commit is contained in:
Xingkai Yu
2026-04-14 03:06:27 +08:00
committed by GitHub
11 changed files with 79 additions and 80 deletions
-1
View File
@@ -23,4 +23,3 @@ class Config:
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
+1 -2
View File
@@ -46,7 +46,7 @@ class BlockManager:
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]
return block
def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0
@@ -88,7 +88,6 @@ class BlockManager:
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.prefilled = False
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
+5 -9
View File
@@ -18,7 +18,7 @@ class LLMEngine:
config_fields = {field.name for field in fields(Config)}
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
config = Config(model, **config_kwargs)
Sequence.set_block_size(config.kvcache_block_size)
Sequence.block_size = config.kvcache_block_size
self.ps = []
self.events = []
ctx = mp.get_context("spawn")
@@ -48,10 +48,10 @@ class LLMEngine:
def step(self):
seqs, is_prefill = self.scheduler.schedule()
num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
self.scheduler.postprocess(seqs, token_ids, is_prefill)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
return outputs, num_tokens
def is_finished(self):
@@ -63,8 +63,7 @@ class LLMEngine:
sampling_params: SamplingParams | list[SamplingParams],
use_tqdm: bool = True,
) -> list[str]:
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm)
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)
for prompt, sp in zip(prompts, sampling_params):
@@ -74,7 +73,6 @@ class LLMEngine:
while not self.is_finished():
t = perf_counter()
output, num_tokens = self.step()
if use_tqdm:
if num_tokens > 0:
prefill_throughput = num_tokens / (perf_counter() - t)
else:
@@ -85,10 +83,8 @@ class LLMEngine:
})
for seq_id, token_ids in output:
outputs[seq_id] = token_ids
if use_tqdm:
pbar.update(1)
pbar.close()
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
if use_tqdm:
pbar.close()
return outputs
+23 -16
View File
@@ -26,7 +26,7 @@ class ModelRunner:
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_dtype(hf_config.dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
@@ -92,8 +92,11 @@ class ModelRunner:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
seq_len = min(max_num_batched_tokens, max_model_len)
num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs)
seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)]
for seq in seqs:
seq.num_scheduled_tokens = seq_len
self.run(seqs, True)
torch.cuda.empty_cache()
@@ -106,7 +109,7 @@ class ModelRunner:
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
@@ -134,23 +137,29 @@ class ModelRunner:
block_tables = None
for seq in seqs:
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
start = min(seq.num_cached_tokens, seqlen - 1)
seqlen_q = seq.num_scheduled_tokens
seqlen_k = seqlen
end = start + seqlen_q
input_ids.extend(seq[start:end])
positions.extend(range(start, end))
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if not seq.block_table: # warmup
continue
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
start_block = start // self.block_size
end_block = (end + self.block_size - 1) // self.block_size
for i in range(start_block, end_block):
slot_start = seq.block_table[i] * self.block_size
if i == start_block:
slot_start += start % self.block_size
if i != end_block - 1:
slot_end = seq.block_table[i] * self.block_size + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
slot_mapping.extend(range(slot_start, slot_end))
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
@@ -180,9 +189,7 @@ class ModelRunner:
return input_ids, positions
def prepare_sample(self, seqs: list[Sequence]):
temperatures = []
for seq in seqs:
temperatures.append(seq.temperature)
temperatures = [seq.temperature for seq in seqs]
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures
+22 -9
View File
@@ -22,26 +22,32 @@ class Scheduler:
self.waiting.append(seq)
def schedule(self) -> tuple[list[Sequence], bool]:
# prefill
scheduled_seqs = []
num_seqs = 0
num_batched_tokens = 0
while self.waiting and num_seqs < self.max_num_seqs:
# prefill
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
seq = self.waiting[0]
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
num_tokens = max(seq.num_tokens - seq.num_cached_tokens, 1)
remaining = self.max_num_batched_tokens - num_batched_tokens
if remaining == 0 or (not seq.block_table and not self.block_manager.can_allocate(seq)): # no budget
break
num_seqs += 1
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
break
if not seq.block_table:
self.block_manager.allocate(seq)
num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.num_scheduled_tokens = min(num_tokens, remaining)
if seq.num_scheduled_tokens == num_tokens:
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)
num_batched_tokens += seq.num_scheduled_tokens
if scheduled_seqs:
return scheduled_seqs, True
# decode
while self.running and num_seqs < self.max_num_seqs:
while self.running and len(scheduled_seqs) < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
if self.running:
@@ -50,7 +56,7 @@ class Scheduler:
self.preempt(seq)
break
else:
num_seqs += 1
seq.num_scheduled_tokens = 1
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs
@@ -62,9 +68,16 @@ class Scheduler:
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int]):
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
for seq, token_id in zip(seqs, token_ids):
if is_prefill:
seq.num_cached_tokens = min(seq.num_cached_tokens + seq.num_scheduled_tokens, seq.num_tokens)
if seq.num_cached_tokens < seq.num_tokens or seq.num_completion_tokens > 0: # chunked prefill or re prefill after preemption
seq.num_scheduled_tokens = 0
continue
seq.append_token(token_id)
seq.num_cached_tokens += 1
seq.num_scheduled_tokens = 0
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
+11 -18
View File
@@ -12,13 +12,9 @@ class SequenceStatus(Enum):
class Sequence:
block_size: int = 0 # invalid value, will be set by set_block_size
block_size = 256
counter = count()
@classmethod
def set_block_size(cls, block_size: int):
cls.block_size = block_size
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
self.seq_id = next(Sequence.counter)
self.status = SequenceStatus.WAITING
@@ -26,12 +22,12 @@ class Sequence:
self.last_token = token_ids[-1]
self.num_tokens = len(self.token_ids)
self.num_prompt_tokens = len(token_ids)
self.num_cached_tokens = 0
self.num_cached_tokens = 0 # tokens that don't need prefill
self.num_scheduled_tokens = 0
self.block_table = []
self.temperature = sampling_params.temperature
self.max_tokens = sampling_params.max_tokens
self.ignore_eos = sampling_params.ignore_eos
self.prefilled = False
def __len__(self):
return self.num_tokens
@@ -55,10 +51,6 @@ class Sequence:
def completion_token_ids(self):
return self.token_ids[self.num_prompt_tokens:]
@property
def num_cached_blocks(self):
return self.num_cached_tokens // self.block_size
@property
def num_blocks(self):
return (self.num_tokens + self.block_size - 1) // self.block_size
@@ -75,15 +67,16 @@ class Sequence:
self.token_ids.append(token_id)
self.last_token = token_id
self.num_tokens += 1
self.prefilled = True
def __getstate__(self):
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
self.last_token if self.prefilled else self.token_ids)
last_state = self.token_ids if self.num_completion_tokens == 0 or self.num_cached_tokens < self.num_tokens else self.last_token
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
def __setstate__(self, state):
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
if self.prefilled:
self.last_token = state[-1]
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state
if isinstance(last_state, list):
self.token_ids = last_state
self.last_token = self.token_ids[-1]
else:
self.token_ids = state[-1]
self.token_ids = []
self.last_token = last_state
-3
View File
@@ -5,9 +5,6 @@ import torch.nn.functional as F
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1)
+1 -2
View File
@@ -141,8 +141,7 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if param_data.dim() == 1:
# bias is not sharded in RowParallelLinear
if param_data.ndim == 1:
param_data.copy_(loaded_weight)
return
shard_size = param_data.size(self.tp_dim)
-2
View File
@@ -54,8 +54,6 @@ def get_rope(
rotary_dim: int,
max_position: int,
base: float,
rope_scaling: dict | None = None,
):
assert rope_scaling is None
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb
-3
View File
@@ -4,9 +4,6 @@ from torch import nn
class Sampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+3 -2
View File
@@ -23,7 +23,7 @@ class Qwen3Attention(nn.Module):
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
rope_scaling: dict | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
@@ -51,12 +51,13 @@ class Qwen3Attention(nn.Module):
hidden_size,
bias=False,
)
if isinstance(rope_scaling, dict):
rope_theta = rope_scaling.get("rope_theta", rope_theta)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,