support chunked prefill and fix minor bug
This commit is contained in:
@@ -26,7 +26,7 @@ class ModelRunner:
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(hf_config.torch_dtype)
+        torch.set_default_dtype(hf_config.dtype)
         torch.set_default_device("cuda")
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
@@ -92,8 +92,11 @@ class ModelRunner:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
         max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        seq_len = min(max_num_batched_tokens, max_model_len)
+        num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)]
+        for seq in seqs:
+            seq.num_scheduled_tokens = seq_len
         self.run(seqs, True)
         torch.cuda.empty_cache()
 
@@ -106,7 +109,7 @@ class ModelRunner:
         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
-        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         assert config.num_kvcache_blocks > 0
         self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
@@ -134,23 +137,29 @@ class ModelRunner:
         block_tables = None
         for seq in seqs:
             seqlen = len(seq)
-            input_ids.extend(seq[seq.num_cached_tokens:])
-            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
-            seqlen_q = seqlen - seq.num_cached_tokens
+            start = min(seq.num_cached_tokens, seqlen - 1)
+            seqlen_q = seq.num_scheduled_tokens
             seqlen_k = seqlen
+            end = start + seqlen_q
+            input_ids.extend(seq[start:end])
+            positions.extend(range(start, end))
             cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
             if not seq.block_table: # warmup
                 continue
-            for i in range(seq.num_cached_blocks, seq.num_blocks):
-                start = seq.block_table[i] * self.block_size
-                if i != seq.num_blocks - 1:
-                    end = start + self.block_size
+            start_block = start // self.block_size
+            end_block = (end + self.block_size - 1) // self.block_size
+            for i in range(start_block, end_block):
+                slot_start = seq.block_table[i] * self.block_size
+                if i == start_block:
+                    slot_start += start % self.block_size
+                if i != end_block - 1:
+                    slot_end = seq.block_table[i] * self.block_size + self.block_size
                 else:
-                    end = start + seq.last_block_num_tokens
-                slot_mapping.extend(list(range(start, end)))
+                    slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
+                slot_mapping.extend(range(slot_start, slot_end))
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
Reference in New Issue
Block a user