from dataclasses import dataclass import torch @dataclass(slots=True) class Context: """全局上下文:存储当前推理步骤的注意力相关元数据。 这个对象在每次推理步骤开始时被 ModelRunner 设置,在模型的前向传播中 被 Attention 层读取。它是一个全局单例,避免了通过函数参数层层传递。 Attributes: is_prefill: 当前是否为 prefill 阶段。 cu_seqlens_q: 查询的累积序列长度(prefill 阶段使用),标记每个序列的边界。 cu_seqlens_k: 键值的累积序列长度(prefill 阶段使用),可能与 cu_seqlens_q 不同(前缀缓存)。 max_seqlen_q: 批次中最长的查询序列长度(prefill 使用)。 max_seqlen_k: 批次中最长的键值序列长度(prefill 使用)。 slot_mapping: 每个 token 在 KV cache 中的存储位置索引(用于写入新的 K/V)。 context_lens: 每个序列的上下文总长度(decode 阶段使用)。 block_tables: KV cache 块映射表,将逻辑块映射到物理块(decode 和前缀缓存使用)。 """ is_prefill: bool = False cu_seqlens_q: torch.Tensor | None = None cu_seqlens_k: torch.Tensor | None = None max_seqlen_q: int = 0 max_seqlen_k: int = 0 slot_mapping: torch.Tensor | None = None context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None _CONTEXT = Context() def get_context(): """获取当前全局上下文。""" return _CONTEXT def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None): """设置当前推理步骤的全局上下文。""" global _CONTEXT _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) def reset_context(): """重置全局上下文(推理步骤结束后调用)。""" global _CONTEXT _CONTEXT = Context()