add Chinese annotations to all source files for learning purposes

Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 21:33:15 +08:00
parent bb823b3e06
commit ffd2defdfc
19 changed files with 656 additions and 34 deletions
+33 -4
View File
@@ -18,6 +18,12 @@ def store_kvcache_kernel(
slot_mapping_ptr,
D: tl.constexpr,
):
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
每个 Triton program 处理一个 token 的 K/V 写入。
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
"""
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
@@ -31,6 +37,15 @@ def store_kvcache_kernel(
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
"""将新计算的 K/V 向量存储到 KV cache 中。
Args:
key: [N, num_heads, head_dim] 新计算的 K 向量。
value: [N, num_heads, head_dim] 新计算的 V 向量。
k_cache: KV cache 中 K 的存储区域。
v_cache: KV cache 中 V 的存储区域。
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
"""
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
@@ -41,6 +56,18 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor,
class Attention(nn.Module):
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
支持 two 阶段的注意力计算:
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
已缓存的 K/V,而不是用当前计算的 K/V。
Attributes:
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
"""
def __init__(
self,
@@ -54,22 +81,24 @@ class Attention(nn.Module):
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
# 将新计算的 K/V 写入 KV cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # prefix cache
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
k, v = k_cache, v_cache
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o