Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

105 lines
4.3 KiB
Python

import torch
from torch import nn
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
):
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
每个 Triton program 处理一个 token 的 K/V 写入。
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
"""
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
"""将新计算的 K/V 向量存储到 KV cache 中。
Args:
key: [N, num_heads, head_dim] 新计算的 K 向量。
value: [N, num_heads, head_dim] 新计算的 V 向量。
k_cache: KV cache 中 K 的存储区域。
v_cache: KV cache 中 V 的存储区域。
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
"""
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
class Attention(nn.Module):
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
支持 two 阶段的注意力计算:
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
已缓存的 K/V,而不是用当前计算的 K/V。
Attributes:
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
"""
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
# 将新计算的 K/V 写入 KV cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
k, v = k_cache, v_cache
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o