support chunked prefill and fix minor bug
This commit is contained in:
@@ -23,4 +23,3 @@ class Config:
|
||||
assert 1 <= self.tensor_parallel_size <= 8
|
||||
self.hf_config = AutoConfig.from_pretrained(self.model)
|
||||
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
||||
assert self.max_num_batched_tokens >= self.max_model_len
|
||||
|
||||
@@ -88,7 +88,6 @@ class BlockManager:
|
||||
if block.ref_count == 0:
|
||||
self._deallocate_block(block_id)
|
||||
seq.num_cached_tokens = 0
|
||||
seq.prefilled = False
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
|
||||
@@ -18,7 +18,7 @@ class LLMEngine:
|
||||
config_fields = {field.name for field in fields(Config)}
|
||||
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
||||
config = Config(model, **config_kwargs)
|
||||
Sequence.set_block_size(config.kvcache_block_size)
|
||||
Sequence.block_size = config.kvcache_block_size
|
||||
self.ps = []
|
||||
self.events = []
|
||||
ctx = mp.get_context("spawn")
|
||||
@@ -48,10 +48,10 @@ class LLMEngine:
|
||||
|
||||
def step(self):
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
self.scheduler.postprocess(seqs, token_ids)
|
||||
self.scheduler.postprocess(seqs, token_ids, is_prefill)
|
||||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||||
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
|
||||
return outputs, num_tokens
|
||||
|
||||
def is_finished(self):
|
||||
|
||||
@@ -26,7 +26,7 @@ class ModelRunner:
|
||||
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
|
||||
torch.cuda.set_device(rank)
|
||||
default_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(hf_config.torch_dtype)
|
||||
torch.set_default_dtype(hf_config.dtype)
|
||||
torch.set_default_device("cuda")
|
||||
self.model = Qwen3ForCausalLM(hf_config)
|
||||
load_model(self.model, config.model)
|
||||
@@ -92,8 +92,11 @@ class ModelRunner:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
|
||||
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
|
||||
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
|
||||
seq_len = min(max_num_batched_tokens, max_model_len)
|
||||
num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs)
|
||||
seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)]
|
||||
for seq in seqs:
|
||||
seq.num_scheduled_tokens = seq_len
|
||||
self.run(seqs, True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -106,7 +109,7 @@ class ModelRunner:
|
||||
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
|
||||
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert config.num_kvcache_blocks > 0
|
||||
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
||||
@@ -134,23 +137,29 @@ class ModelRunner:
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
seqlen = len(seq)
|
||||
input_ids.extend(seq[seq.num_cached_tokens:])
|
||||
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
||||
seqlen_q = seqlen - seq.num_cached_tokens
|
||||
start = min(seq.num_cached_tokens, seqlen - 1)
|
||||
seqlen_q = seq.num_scheduled_tokens
|
||||
seqlen_k = seqlen
|
||||
end = start + seqlen_q
|
||||
input_ids.extend(seq[start:end])
|
||||
positions.extend(range(start, end))
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
if not seq.block_table: # warmup
|
||||
continue
|
||||
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
||||
start = seq.block_table[i] * self.block_size
|
||||
if i != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
start_block = start // self.block_size
|
||||
end_block = (end + self.block_size - 1) // self.block_size
|
||||
for i in range(start_block, end_block):
|
||||
slot_start = seq.block_table[i] * self.block_size
|
||||
if i == start_block:
|
||||
slot_start += start % self.block_size
|
||||
if i != end_block - 1:
|
||||
slot_end = seq.block_table[i] * self.block_size + self.block_size
|
||||
else:
|
||||
end = start + seq.last_block_num_tokens
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
|
||||
slot_mapping.extend(range(slot_start, slot_end))
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
@@ -22,26 +22,32 @@ class Scheduler:
|
||||
self.waiting.append(seq)
|
||||
|
||||
def schedule(self) -> tuple[list[Sequence], bool]:
|
||||
# prefill
|
||||
scheduled_seqs = []
|
||||
num_seqs = 0
|
||||
num_batched_tokens = 0
|
||||
while self.waiting and num_seqs < self.max_num_seqs:
|
||||
|
||||
# prefill
|
||||
while self.waiting and len(scheduled_seqs) < self.max_num_seqs:
|
||||
seq = self.waiting[0]
|
||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||
num_tokens = max(seq.num_tokens - seq.num_cached_tokens, 1)
|
||||
remaining = self.max_num_batched_tokens - num_batched_tokens
|
||||
if remaining == 0 or (not seq.block_table and not self.block_manager.can_allocate(seq)): # no budget
|
||||
break
|
||||
num_seqs += 1
|
||||
if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq
|
||||
break
|
||||
if not seq.block_table:
|
||||
self.block_manager.allocate(seq)
|
||||
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
||||
seq.num_scheduled_tokens = min(num_tokens, remaining)
|
||||
if seq.num_scheduled_tokens == num_tokens:
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
self.waiting.popleft()
|
||||
self.running.append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
num_batched_tokens += seq.num_scheduled_tokens
|
||||
if scheduled_seqs:
|
||||
return scheduled_seqs, True
|
||||
|
||||
# decode
|
||||
while self.running and num_seqs < self.max_num_seqs:
|
||||
while self.running and len(scheduled_seqs) < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
while not self.block_manager.can_append(seq):
|
||||
if self.running:
|
||||
@@ -50,7 +56,7 @@ class Scheduler:
|
||||
self.preempt(seq)
|
||||
break
|
||||
else:
|
||||
num_seqs += 1
|
||||
seq.num_scheduled_tokens = 1
|
||||
self.block_manager.may_append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
assert scheduled_seqs
|
||||
@@ -62,9 +68,16 @@ class Scheduler:
|
||||
self.block_manager.deallocate(seq)
|
||||
self.waiting.appendleft(seq)
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]):
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool):
|
||||
for seq, token_id in zip(seqs, token_ids):
|
||||
if is_prefill:
|
||||
seq.num_cached_tokens = min(seq.num_cached_tokens + seq.num_scheduled_tokens, seq.num_tokens)
|
||||
if seq.num_cached_tokens < seq.num_tokens or seq.num_completion_tokens > 0: # chunked prefill or re prefill after preemption
|
||||
seq.num_scheduled_tokens = 0
|
||||
continue
|
||||
seq.append_token(token_id)
|
||||
seq.num_cached_tokens += 1
|
||||
seq.num_scheduled_tokens = 0
|
||||
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED
|
||||
self.block_manager.deallocate(seq)
|
||||
|
||||
+11
-18
@@ -12,13 +12,9 @@ class SequenceStatus(Enum):
|
||||
|
||||
|
||||
class Sequence:
|
||||
block_size: int = 0 # invalid value, will be set by set_block_size
|
||||
block_size = 256
|
||||
counter = count()
|
||||
|
||||
@classmethod
|
||||
def set_block_size(cls, block_size: int):
|
||||
cls.block_size = block_size
|
||||
|
||||
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
|
||||
self.seq_id = next(Sequence.counter)
|
||||
self.status = SequenceStatus.WAITING
|
||||
@@ -26,12 +22,12 @@ class Sequence:
|
||||
self.last_token = token_ids[-1]
|
||||
self.num_tokens = len(self.token_ids)
|
||||
self.num_prompt_tokens = len(token_ids)
|
||||
self.num_cached_tokens = 0
|
||||
self.num_cached_tokens = 0 # tokens that don't need prefill
|
||||
self.num_scheduled_tokens = 0
|
||||
self.block_table = []
|
||||
self.temperature = sampling_params.temperature
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
self.ignore_eos = sampling_params.ignore_eos
|
||||
self.prefilled = False
|
||||
|
||||
def __len__(self):
|
||||
return self.num_tokens
|
||||
@@ -55,10 +51,6 @@ class Sequence:
|
||||
def completion_token_ids(self):
|
||||
return self.token_ids[self.num_prompt_tokens:]
|
||||
|
||||
@property
|
||||
def num_cached_blocks(self):
|
||||
return self.num_cached_tokens // self.block_size
|
||||
|
||||
@property
|
||||
def num_blocks(self):
|
||||
return (self.num_tokens + self.block_size - 1) // self.block_size
|
||||
@@ -75,15 +67,16 @@ class Sequence:
|
||||
self.token_ids.append(token_id)
|
||||
self.last_token = token_id
|
||||
self.num_tokens += 1
|
||||
self.prefilled = True
|
||||
|
||||
def __getstate__(self):
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
|
||||
self.last_token if self.prefilled else self.token_ids)
|
||||
last_state = self.token_ids if self.num_completion_tokens == 0 or self.num_cached_tokens < self.num_tokens else self.last_token
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
|
||||
if self.prefilled:
|
||||
self.last_token = state[-1]
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state
|
||||
if isinstance(last_state, list):
|
||||
self.token_ids = last_state
|
||||
self.last_token = self.token_ids[-1]
|
||||
else:
|
||||
self.token_ids = state[-1]
|
||||
self.token_ids = []
|
||||
self.last_token = last_state
|
||||
|
||||
@@ -54,8 +54,6 @@ def get_rope(
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: float,
|
||||
rope_scaling: dict | None = None,
|
||||
):
|
||||
assert rope_scaling is None
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
return rotary_emb
|
||||
|
||||
@@ -23,7 +23,7 @@ class Qwen3Attention(nn.Module):
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: tuple | None = None,
|
||||
rope_scaling: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
tp_size = dist.get_world_size()
|
||||
@@ -51,12 +51,13 @@ class Qwen3Attention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
if isinstance(rope_scaling, dict):
|
||||
rope_theta = rope_scaling.get("rope_theta", rope_theta)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
|
||||
Reference in New Issue
Block a user