add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -18,6 +18,12 @@ def store_kvcache_kernel(
|
||||
slot_mapping_ptr,
|
||||
D: tl.constexpr,
|
||||
):
|
||||
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
|
||||
|
||||
每个 Triton program 处理一个 token 的 K/V 写入。
|
||||
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
|
||||
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
|
||||
"""
|
||||
idx = tl.program_id(0)
|
||||
slot = tl.load(slot_mapping_ptr + idx)
|
||||
if slot == -1: return
|
||||
@@ -31,6 +37,15 @@ def store_kvcache_kernel(
|
||||
|
||||
|
||||
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
||||
"""将新计算的 K/V 向量存储到 KV cache 中。
|
||||
|
||||
Args:
|
||||
key: [N, num_heads, head_dim] 新计算的 K 向量。
|
||||
value: [N, num_heads, head_dim] 新计算的 V 向量。
|
||||
k_cache: KV cache 中 K 的存储区域。
|
||||
v_cache: KV cache 中 V 的存储区域。
|
||||
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
|
||||
"""
|
||||
N, num_heads, head_dim = key.shape
|
||||
D = num_heads * head_dim
|
||||
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
||||
@@ -41,6 +56,18 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor,
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
|
||||
|
||||
支持 two 阶段的注意力计算:
|
||||
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
|
||||
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
|
||||
|
||||
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
|
||||
已缓存的 K/V,而不是用当前计算的 K/V。
|
||||
|
||||
Attributes:
|
||||
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -54,22 +81,24 @@ class Attention(nn.Module):
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.k_cache = self.v_cache = torch.tensor([])
|
||||
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
context = get_context()
|
||||
k_cache, v_cache = self.k_cache, self.v_cache
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
# 将新计算的 K/V 写入 KV cache
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
if context.is_prefill:
|
||||
if context.block_tables is not None: # prefix cache
|
||||
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
|
||||
k, v = k_cache, v_cache
|
||||
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else: # decode
|
||||
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
return o
|
||||
|
||||
Reference in New Issue
Block a user