ffd2defdfc
Annotated 16 source files covering the full architecture: engine (scheduler, block manager, model runner), layers (attention, linear, sampler, etc.), model (qwen3), and utils. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
105 lines
4.3 KiB
Python
105 lines
4.3 KiB
Python
import torch
|
|
from torch import nn
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
|
from nanovllm.utils.context import get_context
|
|
|
|
|
|
@triton.jit
|
|
def store_kvcache_kernel(
|
|
key_ptr,
|
|
key_stride,
|
|
value_ptr,
|
|
value_stride,
|
|
k_cache_ptr,
|
|
v_cache_ptr,
|
|
slot_mapping_ptr,
|
|
D: tl.constexpr,
|
|
):
|
|
"""Triton kernel:将新计算的 K/V 向量写入 KV cache。
|
|
|
|
每个 Triton program 处理一个 token 的 K/V 写入。
|
|
slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。
|
|
slot == -1 表示该 token 不需要写入(如 warmup 阶段)。
|
|
"""
|
|
idx = tl.program_id(0)
|
|
slot = tl.load(slot_mapping_ptr + idx)
|
|
if slot == -1: return
|
|
key_offsets = idx * key_stride + tl.arange(0, D)
|
|
value_offsets = idx * value_stride + tl.arange(0, D)
|
|
key = tl.load(key_ptr + key_offsets)
|
|
value = tl.load(value_ptr + value_offsets)
|
|
cache_offsets = slot * D + tl.arange(0, D)
|
|
tl.store(k_cache_ptr + cache_offsets, key)
|
|
tl.store(v_cache_ptr + cache_offsets, value)
|
|
|
|
|
|
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
|
"""将新计算的 K/V 向量存储到 KV cache 中。
|
|
|
|
Args:
|
|
key: [N, num_heads, head_dim] 新计算的 K 向量。
|
|
value: [N, num_heads, head_dim] 新计算的 V 向量。
|
|
k_cache: KV cache 中 K 的存储区域。
|
|
v_cache: KV cache 中 V 的存储区域。
|
|
slot_mapping: [N] 每个 token 对应的 cache slot 索引。
|
|
"""
|
|
N, num_heads, head_dim = key.shape
|
|
D = num_heads * head_dim
|
|
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
|
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
|
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
|
assert slot_mapping.numel() == N
|
|
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
|
|
|
|
|
class Attention(nn.Module):
|
|
"""注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。
|
|
|
|
支持 two 阶段的注意力计算:
|
|
- Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。
|
|
- Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。
|
|
|
|
当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取
|
|
已缓存的 K/V,而不是用当前计算的 K/V。
|
|
|
|
Attributes:
|
|
k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_heads,
|
|
head_dim,
|
|
scale,
|
|
num_kv_heads,
|
|
):
|
|
super().__init__()
|
|
self.num_heads = num_heads
|
|
self.head_dim = head_dim
|
|
self.scale = scale
|
|
self.num_kv_heads = num_kv_heads
|
|
self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定
|
|
|
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
|
context = get_context()
|
|
k_cache, v_cache = self.k_cache, self.v_cache
|
|
if k_cache.numel() and v_cache.numel():
|
|
# 将新计算的 K/V 写入 KV cache
|
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
|
if context.is_prefill:
|
|
if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V
|
|
k, v = k_cache, v_cache
|
|
# Flash Attention 变长版:支持不同长度的序列在同一批次中计算
|
|
o = flash_attn_varlen_func(q, k, v,
|
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
|
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
|
else: # Decode 阶段:从 KV cache 中读取所有历史 K/V
|
|
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
|
softmax_scale=self.scale, causal=True)
|
|
return o
|