diff --git a/ANNOTATIONS.md b/ANNOTATIONS.md new file mode 100644 index 0000000..cd594d4 --- /dev/null +++ b/ANNOTATIONS.md @@ -0,0 +1,22 @@ +# Nano-vLLM 注释说明 + +**已添加注释的文件(16个):** + +| 模块 | 文件 | 注释要点 | +|------|------|---------| +| 入口 | `__init__.py` | 项目架构和数据流概览 | +| 配置 | `config.py`, `sampling_params.py` | 每个参数的含义和作用 | +| 引擎 | `sequence.py` | 序列状态机、block_table、序列化机制 | +| 引擎 | `block_manager.py` | 前缀缓存原理、哈希链式计算、引用计数 | +| 引擎 | `scheduler.py` | prefill/decode调度策略、chunked prefill、抢占机制 | +| 引擎 | `model_runner.py` | KV cache分配、CUDA Graph捕获、TP共享内存通信 | +| 引擎 | `llm_engine.py` | 引擎初始化流程、step循环、吞吐量统计 | +| 模型 | `qwen3.py` | Qwen3架构(GQA、Q/K Norm)、融合模块映射 | +| 层 | `attention.py` | Triton kernel写KV cache、Flash Attention两阶段 | +| 层 | `linear.py` | 5种并行线性层(列切/行切/融合QKV/融合gate_up) | +| 层 | `sampler.py` | Gumbel-like采样方法 | +| 层 | `activation.py` | SwiGLU (SiLU * up) | +| 层 | `layernorm.py` | 残差融合RMSNorm | +| 层 | `embed_head.py` | 词表并行Embedding、LM Head前缀优化 | +| 层 | `rotary_embedding.py` | RoPE原理和预计算缓存 | +| 工具 | `context.py`, `loader.py` | 全局上下文机制、safetensors权重加载 | diff --git a/nanovllm/__init__.py b/nanovllm/__init__.py index 551af23..ef3e981 100644 --- a/nanovllm/__init__.py +++ b/nanovllm/__init__.py @@ -1,2 +1,12 @@ +# nano-vllm: 一个轻量级的 vLLM 实现 +# +# 核心架构: +# LLM (入口) → LLMEngine (引擎) → Scheduler (调度器) + ModelRunner (模型运行器) +# ↓ ↓ +# BlockManager (KV缓存管理) Qwen3ForCausalLM (模型) +# +# 数据流: +# 用户请求 → Sequence → Scheduler 调度 → ModelRunner 准备输入 → 模型前向 → 采样 → 返回结果 + from nanovllm.llm import LLM from nanovllm.sampling_params import SamplingParams diff --git a/nanovllm/config.py b/nanovllm/config.py index 7066cbe..81bb9d1 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -5,6 +5,22 @@ from transformers import AutoConfig @dataclass(slots=True) class Config: + """推理引擎的全局配置。 + + Args: + model: HuggingFace 模型的本地路径(必须是一个目录,包含 safetensors 权重文件和 config.json)。 + max_num_batched_tokens: 单次调度(schedule)中允许的最大 token 总数,控制 prefill 阶段的批处理粒度。 + max_num_seqs: 单次调度中允许的最大序列数,限制同时处理的请求数量。 + max_model_len: 模型支持的最大序列长度(prompt + 生成的总长度),会被 HF config 中的 max_position_embeddings 截断。 + gpu_memory_utilization: GPU 显存使用率(0~1),决定 KV cache 可用显存大小。 + tensor_parallel_size: 张量并行度(TP),即使用多少张 GPU 来并行运行模型。 + enforce_eager: 是否强制使用 eager 模式(不使用 CUDA Graph),调试时设为 True。 + hf_config: 从模型目录自动加载的 HuggingFace 配置对象,由 __post_init__ 自动设置。 + eos: End-of-Sequence token 的 ID,由引擎初始化时从 tokenizer 获取,用于判断生成是否结束。 + kvcache_block_size: KV cache 的块大小(token 数),每个 block 存储这么多 token 的 KV 向量。前缀缓存和 KV cache 管理的最小单位。 + num_kvcache_blocks: KV cache 的总块数,-1 表示尚未分配,会在 ModelRunner.allocate_kv_cache() 中根据可用显存自动计算。 + """ + model: str max_num_batched_tokens: int = 16384 max_num_seqs: int = 512 @@ -21,5 +37,7 @@ class Config: assert os.path.isdir(self.model) assert self.kvcache_block_size % 256 == 0 assert 1 <= self.tensor_parallel_size <= 8 + # 从模型目录加载 HuggingFace 配置(如 num_layers, hidden_size, num_heads 等) self.hf_config = AutoConfig.from_pretrained(self.model) + # 确保最大序列长度不超过模型支持的位置编码上限 self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py index a48989c..25dfae0 100644 --- a/nanovllm/engine/block_manager.py +++ b/nanovllm/engine/block_manager.py @@ -6,6 +6,16 @@ from nanovllm.engine.sequence import Sequence class Block: + """KV cache 的物理块。 + + 每个 Block 对应 GPU 显存中一块固定大小的 KV cache 存储,可以存放 block_size 个 token 的 KV 向量。 + + Attributes: + block_id: 物理块的唯一 ID。 + ref_count: 引用计数,有多少个序列正在共享这个块(前缀缓存复用)。 + hash: 该块内容的哈希值,用于前缀缓存查找。初始为 -1 表示未计算哈希。 + token_ids: 该块对应的 token ID 列表,用于验证缓存命中时内容是否一致。 + """ def __init__(self, block_id): self.block_id = block_id @@ -14,16 +24,33 @@ class Block: self.token_ids = [] def update(self, hash: int, token_ids: list[int]): + """更新块的内容哈希和 token_ids。""" self.hash = hash self.token_ids = token_ids def reset(self): + """重置块状态,准备被重新分配。""" self.ref_count = 1 self.hash = -1 self.token_ids = [] class BlockManager: + """管理 KV cache 物理块的分配、释放和前缀缓存。 + + 核心设计: + - 将 KV cache 分成固定大小的物理块(block_size 个 token/块)。 + - 使用哈希表实现前缀缓存(prefix caching):不同序列如果拥有相同的 prompt 前缀, + 可以共享同一组 KV cache 块,避免重复计算。 + - 块通过引用计数(ref_count)管理生命周期,所有引用释放后块变为空闲可复用。 + + Attributes: + block_size: 每个块存储的 token 数量。 + blocks: 所有物理块的列表,索引即 block_id。 + hash_to_block_id: 内容哈希 → 块 ID 的映射,用于前缀缓存查找。 + free_block_ids: 空闲块的 ID 队列。 + used_block_ids: 正在使用的块 ID 集合。 + """ def __init__(self, num_blocks: int, block_size: int): self.block_size = block_size @@ -34,6 +61,15 @@ class BlockManager: @classmethod def compute_hash(cls, token_ids: list[int], prefix: int = -1): + """计算一个块的内容哈希值。 + + 哈希是链式的:每个块的哈希值依赖于前一个块的哈希值(prefix 参数), + 这确保了前缀缓存的一致性——只有完全相同的 token 序列前缀才会产生相同的哈希链。 + + Args: + token_ids: 该块对应的 token ID 列表。 + prefix: 前一个块的哈希值,-1 表示第一个块(无前缀)。 + """ h = xxhash.xxh64() if prefix != -1: h.update(prefix.to_bytes(8, "little")) @@ -41,9 +77,11 @@ class BlockManager: return h.intdigest() def _allocate_block(self) -> int: + """从空闲池中分配一个物理块。""" block_id = self.free_block_ids.popleft() block = self.blocks[block_id] assert block.ref_count == 0 + # 如果该块之前有哈希记录,先从哈希表中移除 if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: del self.hash_to_block_id[block.hash] block.reset() @@ -51,11 +89,22 @@ class BlockManager: return block_id def _deallocate_block(self, block_id: int): + """释放一个物理块回空闲池。""" assert self.blocks[block_id].ref_count == 0 self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) def can_allocate(self, seq: Sequence) -> int: + """检查是否有足够的空闲块来分配给序列,同时计算前缀缓存命中数。 + + 遍历序列的所有逻辑块,逐个计算哈希并与已有缓存比对: + - 如果哈希匹配且 token 内容一致,则该块可以复用(前缀缓存命中)。 + - 一旦遇到不匹配的块就停止,因为前缀缓存要求连续匹配。 + - 最后检查剩余空闲块是否足够分配未命中部分。 + + Returns: + 前缀缓存命中的块数,如果没有足够的空闲块则返回 -1。 + """ h = -1 num_cached_blocks = 0 num_new_blocks = seq.num_blocks @@ -67,14 +116,25 @@ class BlockManager: break num_cached_blocks += 1 if block_id in self.used_block_ids: + # 块已在使用中(被其他序列共享),不需要额外分配 num_new_blocks -= 1 if len(self.free_block_ids) < num_new_blocks: return -1 return num_cached_blocks def allocate(self, seq: Sequence, num_cached_blocks: int): + """为序列分配 KV cache 物理块。 + + 前缀缓存命中部分:增加已有块的引用计数。 + 未命中部分:从空闲池分配新块。 + + Args: + seq: 要分配块的序列。 + num_cached_blocks: 前缀缓存命中的块数(由 can_allocate 计算)。 + """ assert not seq.block_table h = -1 + # 处理缓存命中部分:复用已有块 for i in range(num_cached_blocks): token_ids = seq.block(i) h = self.compute_hash(token_ids, h) @@ -83,15 +143,22 @@ class BlockManager: if block_id in self.used_block_ids: block.ref_count += 1 else: + # 块在空闲池中但哈希匹配(之前被释放但内容未覆盖),重新激活 block.ref_count = 1 self.free_block_ids.remove(block_id) self.used_block_ids.add(block_id) seq.block_table.append(block_id) + # 分配新块用于未命中部分 for i in range(num_cached_blocks, seq.num_blocks): seq.block_table.append(self._allocate_block()) seq.num_cached_tokens = num_cached_blocks * self.block_size def deallocate(self, seq: Sequence): + """释放序列占用的所有 KV cache 块。 + + 递减每个块的引用计数,引用计数归零的块被回收进空闲池。 + 注意:遍历顺序为逆序,这是为了优先释放最新分配的块(它们在空闲队列尾部)。 + """ for block_id in reversed(seq.block_table): block = self.blocks[block_id] block.ref_count -= 1 @@ -101,16 +168,30 @@ class BlockManager: seq.block_table.clear() def can_append(self, seq: Sequence) -> bool: + """检查 decode 阶段是否能追加一个 token。 + + 当序列当前最后一个块已满时(len(seq) % block_size == 1,即新增 token 是新块的第一个), + 需要分配一个新块。否则只需要写入已有块,不需要额外分配。 + """ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) def may_append(self, seq: Sequence): + """在 decode 阶段,如果需要则追加一个新块。""" if len(seq) % self.block_size == 1: seq.block_table.append(self._allocate_block()) def hash_blocks(self, seq: Sequence): + """更新序列中已完成计算的块的哈希值。 + + 在每次调度步骤完成后调用,将新计算完毕的块的 token 内容和哈希值 + 注册到哈希表中,以便后续序列复用(前缀缓存)。 + + 只处理从 num_cached_tokens 到当前进度的块(即本次新完成的块)。 + """ start = seq.num_cached_tokens // self.block_size end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size if start == end: return + # 继承前一个块的哈希值作为前缀 h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1 for i in range(start, end): block = self.blocks[seq.block_table[i]] diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 3685094..f706e8f 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -13,40 +13,77 @@ from nanovllm.engine.model_runner import ModelRunner class LLMEngine: + """LLM 推理引擎:整个 nano-vllm 的入口和协调者。 + + 职责: + 1. 初始化配置、启动张量并行的子进程。 + 2. 管理请求(添加/调度/生成)。 + 3. 协调调度器和模型运行器之间的交互。 + + 使用方式: + >>> engine = LLMEngine("/path/to/model", enforce_eager=True) + >>> outputs = engine.generate(["Hello"], SamplingParams(max_tokens=64)) + """ def __init__(self, model, **kwargs): + # 过滤出 Config 支持的参数,忽略其他参数 config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) Sequence.block_size = config.kvcache_block_size + + # 启动张量并行的工作进程(rank 1, 2, ...) + # 使用 "spawn" 方式创建子进程,确保 CUDA context 独立 self.ps = [] self.events = [] ctx = mp.get_context("spawn") for i in range(1, config.tensor_parallel_size): - event = ctx.Event() + event = ctx.Event() # 用于通知工作进程有新任务 process = ctx.Process(target=ModelRunner, args=(config, i, event)) process.start() self.ps.append(process) self.events.append(event) + + # 在主进程中初始化模型运行器(rank 0) self.model_runner = ModelRunner(config, 0, self.events) + # 加载 tokenizer(用于文本 ↔ token ID 转换) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) config.eos = self.tokenizer.eos_token_id + # 初始化调度器(此时 num_kvcache_blocks 已由 ModelRunner 分配完毕) self.scheduler = Scheduler(config) + # 注册退出清理函数 atexit.register(self.exit) def exit(self): + """通知所有进程退出并等待。""" self.model_runner.call("exit") del self.model_runner for p in self.ps: p.join() def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): + """添加一个生成请求到调度器的等待队列。 + + Args: + prompt: 可以是字符串(会被 tokenizer 编码为 token ID)或已编码的 token ID 列表。 + sampling_params: 采样参数。 + """ if isinstance(prompt, str): prompt = self.tokenizer.encode(prompt) seq = Sequence(prompt, sampling_params) self.scheduler.add(seq) def step(self): + """执行一个推理步骤。 + + 1. 调度器选择要处理的序列(prefill 或 decode)。 + 2. 模型运行器执行前向推理和采样。 + 3. 调度器后处理结果(更新缓存、检查终止条件)。 + + Returns: + outputs: 已完成序列的 (seq_id, completion_token_ids) 列表。 + num_tokens: prefill 时为处理的 token 总数(正数),decode 时为序列数的负数。 + """ seqs, is_prefill = self.scheduler.schedule() num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs) token_ids = self.model_runner.call("run", seqs, is_prefill) @@ -55,6 +92,7 @@ class LLMEngine: return outputs, num_tokens def is_finished(self): + """检查是否所有请求都已完成。""" return self.scheduler.is_finished() def generate( @@ -63,6 +101,21 @@ class LLMEngine: sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True, ) -> list[str]: + """批量生成文本的入口方法。 + + 工作流程: + 1. 将所有 prompt 添加为请求。 + 2. 循环执行 step() 直到所有序列完成。 + 3. 收集结果并解码为文本。 + + Args: + prompts: prompt 列表,每个元素可以是字符串或 token ID 列表。 + sampling_params: 单个采样参数(应用于所有 prompt)或每个 prompt 对应的采样参数列表。 + use_tqdm: 是否显示进度条。 + + Returns: + 列表,每个元素是 {"text": 解码文本, "token_ids": token ID 列表}。 + """ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm) if not isinstance(sampling_params, list): sampling_params = [sampling_params] * len(prompts) @@ -73,6 +126,7 @@ class LLMEngine: while not self.is_finished(): t = perf_counter() output, num_tokens = self.step() + # num_tokens > 0 表示 prefill,< 0 表示 decode(取负得到序列数) if num_tokens > 0: prefill_throughput = num_tokens / (perf_counter() - t) else: @@ -85,6 +139,7 @@ class LLMEngine: outputs[seq_id] = token_ids pbar.update(1) pbar.close() + # 按 seq_id 排序输出(保证与输入顺序一致) outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())] outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] return outputs diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 71d9883..0d6911e 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -13,6 +13,18 @@ from nanovllm.utils.loader import load_model class ModelRunner: + """模型运行器:负责模型推理、KV cache 管理、CUDA Graph 捕获和张量并行通信。 + + 在张量并行(TP)模式下: + - Rank 0 是主进程,负责采样和与引擎通信。 + - Rank > 0 是工作进程,通过共享内存(SharedMemory)接收指令。 + - 所有进程共享同一个模型和 KV cache 的分片。 + + 生命周期: + 1. 初始化: 加载模型 → warmup → 分配 KV cache → (可选)捕获 CUDA Graph + 2. 推理: 接收序列 → 准备输入 → 运行模型 → 采样 token + 3. 退出: 释放资源 + """ def __init__(self, config: Config, rank: int, event: Event | list[Event]): self.config = config @@ -23,42 +35,54 @@ class ModelRunner: self.rank = rank self.event = event + # 初始化分布式进程组(NCCL 后端),所有 GPU 通过 TCP 通信 dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank) torch.cuda.set_device(rank) + + # 加载模型权重 default_dtype = torch.get_default_dtype() torch.set_default_dtype(hf_config.dtype) torch.set_default_device("cuda") self.model = Qwen3ForCausalLM(hf_config) load_model(self.model, config.model) self.sampler = Sampler() + + # Warmup: 运行一次前向传播以确定模型本身的显存占用 self.warmup_model() + # 根据剩余显存分配 KV cache self.allocate_kv_cache() + # 捕获 CUDA Graph 以加速 decode 阶段的小批量推理 if not self.enforce_eager: self.capture_cudagraph() + torch.set_default_device("cpu") torch.set_default_dtype(default_dtype) + # 张量并行时,rank > 0 的工作进程进入消息循环 if self.world_size > 1: if rank == 0: + # 主进程创建共享内存,工作进程打开它 self.shm = SharedMemory(name="nanovllm", create=True, size=2**20) dist.barrier() else: dist.barrier() self.shm = SharedMemory(name="nanovllm") - self.loop() + self.loop() # 工作进程在此循环,直到收到 exit 指令 def exit(self): + """释放所有资源并退出。""" if self.world_size > 1: self.shm.close() dist.barrier() if self.rank == 0: - self.shm.unlink() + self.shm.unlink() # 只有创建者需要 unlink if not self.enforce_eager: del self.graphs, self.graph_pool torch.cuda.synchronize() dist.destroy_process_group() def loop(self): + """工作进程的主循环:等待主进程指令,执行对应方法。""" while True: method_name, args = self.read_shm() self.call(method_name, *args) @@ -66,14 +90,16 @@ class ModelRunner: break def read_shm(self): + """从共享内存读取主进程发送的方法调用指令。""" assert self.world_size > 1 and self.rank > 0 - self.event.wait() + self.event.wait() # 等待主进程通知 n = int.from_bytes(self.shm.buf[0:4], "little") method_name, *args = pickle.loads(self.shm.buf[4:n+4]) self.event.clear() return method_name, args def write_shm(self, method_name, *args): + """将方法调用指令写入共享内存,通知工作进程。""" assert self.world_size > 1 and self.rank == 0 data = pickle.dumps([method_name, *args]) n = len(data) @@ -83,12 +109,19 @@ class ModelRunner: event.set() def call(self, method_name, *args): + """调用指定方法。TP 模式下主进程先通知工作进程,再本地执行。""" if self.world_size > 1 and self.rank == 0: self.write_shm(method_name, *args) method = getattr(self, method_name, None) return method(*args) def warmup_model(self): + """预热模型:运行一次最大批量的前向传播。 + + 目的是让 PyTorch 分配所有内部缓存(cuBLAS workspace 等), + 然后通过 empty_cache 释放临时显存,这样后续的 peak memory 统计 + 就只包含模型权重,从而准确计算 KV cache 可用空间。 + """ torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len @@ -101,6 +134,15 @@ class ModelRunner: torch.cuda.empty_cache() def allocate_kv_cache(self): + """根据剩余 GPU 显存分配 KV cache。 + + 计算公式: + 可用显存 = 总显存 × gpu_memory_utilization - 非模型占用 + 其中非模型占用 = 已用显存 - peak(模型权重)+ current(当前模型张量) + + KV cache 形状: (2, num_layers, num_blocks, block_size, num_kv_heads/head_dim) + 其中第一维 2 分别对应 K 和 V cache。 + """ config = self.config hf_config = config.hf_config free, total = torch.cuda.mem_get_info() @@ -109,10 +151,13 @@ class ModelRunner: current = torch.cuda.memory_stats()["allocated_bytes.all.current"] num_kv_heads = hf_config.num_key_value_heads // self.world_size head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads) + # 每个块占用的字节数:2(K+V) × 层数 × block_size × KV头数 × head_dim × dtype字节数 block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes assert config.num_kvcache_blocks > 0 + # 分配 KV cache 张量,形状为 (2, num_layers, num_blocks, block_size, num_kv_heads, head_dim) self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim) + # 将 KV cache 的视图绑定到模型中每个 Attention 层 layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): @@ -121,12 +166,26 @@ class ModelRunner: layer_id += 1 def prepare_block_tables(self, seqs: list[Sequence]): + """将序列的 block_table 列表填充为等长的二维张量,用于 GPU 计算。""" max_len = max(len(seq.block_table) for seq in seqs) + # 用 -1 填充短序列的 block_table block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs] block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) return block_tables def prepare_prefill(self, seqs: list[Sequence]): + """准备 prefill 阶段的模型输入张量。 + + Prefill 阶段需要处理多个序列的 prompt tokens,所有 token 被拼接成一个连续的输入。 + 使用 cu_seqlens(累积序列长度)来标记每个序列的边界,供 flash_attn_varlen 使用。 + + 关键数据: + - input_ids: 所有序列的 token ID 拼接。 + - positions: 每个 token 的位置 ID(考虑前缀缓存偏移)。 + - cu_seqlens_q/k: 查询和键值的累积序列长度。 + - slot_mapping: 将每个 token 映射到 KV cache 中的物理存储位置。 + - block_tables: 前缀缓存命中时需要 block_table 来从 KV cache 读取已缓存的 K/V。 + """ input_ids = [] positions = [] cu_seqlens_q = [0] @@ -136,18 +195,19 @@ class ModelRunner: slot_mapping = [] block_tables = None for seq in seqs: - start = seq.num_cached_tokens + start = seq.num_cached_tokens # 跳过已缓存的 token seqlen_q = seq.num_scheduled_tokens end = start + seqlen_q - seqlen_k = end + seqlen_k = end # KV 的长度是从 0 到 end(包括缓存前缀) input_ids.extend(seq[start:end]) positions.extend(range(start, end)) cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) - if not seq.block_table: # warmup + if not seq.block_table: # warmup 阶段没有 block_table continue + # 计算 slot_mapping:每个 token 对应 KV cache 中的哪个 slot start_block = start // self.block_size end_block = (end + self.block_size - 1) // self.block_size for i in range(start_block, end_block): @@ -159,7 +219,7 @@ class ModelRunner: else: slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size slot_mapping.extend(range(slot_start, slot_end)) - if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache + if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # 前缀缓存命中时,KV 长度 > Q 长度 block_tables = self.prepare_block_tables(seqs) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) @@ -170,6 +230,18 @@ class ModelRunner: return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): + """准备 decode 阶段的模型输入张量。 + + Decode 阶段每个序列只处理 1 个 token(最新生成的 token)。 + 模型从 KV cache 中读取之前所有的 K/V 向量来做注意力计算。 + + 关键数据: + - input_ids: 每个序列的最新 token ID。 + - positions: 每个 token 的位置 ID(序列长度 - 1)。 + - slot_mapping: 新 token 的 KV 写入位置。 + - context_lens: 每个序列的上下文总长度。 + - block_tables: KV cache 块映射表。 + """ input_ids = [] positions = [] slot_mapping = [] @@ -178,6 +250,7 @@ class ModelRunner: input_ids.append(seq.last_token) positions.append(len(seq) - 1) context_lens.append(len(seq)) + # slot = 最后一个块的起始位置 + 该块内已有 token 数 - 1 slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) @@ -188,19 +261,30 @@ class ModelRunner: return input_ids, positions def prepare_sample(self, seqs: list[Sequence]): + """准备采样所需的温度参数张量。""" temperatures = [seq.temperature for seq in seqs] temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) return temperatures @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): + """运行模型前向传播。 + + 对于 decode 阶段的小批量(<=512),使用 CUDA Graph 加速: + CUDA Graph 将整个计算图"录制"下来,后续只需回放即可,避免了 + CPU 端的 kernel launch 开销,对 decode(每个 step 计算量很小)尤为有效。 + """ if is_prefill or self.enforce_eager or input_ids.size(0) > 512: + # 直接运行:prefill(批量动态)、eager 模式、或大批量 decode return self.model.compute_logits(self.model(input_ids, positions)) else: + # 使用 CUDA Graph 回放加速小批量 decode bs = input_ids.size(0) context = get_context() + # 选择 >= bs 的最小预捕获图大小 graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] graph_vars = self.graph_vars + # 将实际输入拷贝到图预分配的固定大小缓冲区中 graph_vars["input_ids"][:bs] = input_ids graph_vars["positions"][:bs] = positions graph_vars["slot_mapping"].fill_(-1) @@ -208,10 +292,16 @@ class ModelRunner: graph_vars["context_lens"].zero_() graph_vars["context_lens"][:bs] = context.context_lens graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables + # 回放图(比重新执行快,跳过了 Python/PyTorch 调度开销) graph.replay() return self.model.compute_logits(graph_vars["outputs"][:bs]) def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: + """执行一次完整的推理步骤:准备输入 → 模型前向 → 采样。 + + Returns: + 采样得到的 token ID 列表(仅 rank 0 返回有效值)。 + """ input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) temperatures = self.prepare_sample(seqs) if self.rank == 0 else None logits = self.run_model(input_ids, positions, is_prefill) @@ -221,28 +311,39 @@ class ModelRunner: @torch.inference_mode() def capture_cudagraph(self): + """预捕获不同批量大小的 CUDA Graph。 + + CUDA Graph 要求输入张量的地址不变(同一个内存池),所以需要预分配 + 固定大小的输入缓冲区,并为每个 batch size 录制一个图。 + + 预捕获的 batch size: [1, 2, 4, 8, 16, 32, ..., max_bs] + 运行时选择 >= 实际 batch size 的最小预捕获图。 + """ config = self.config hf_config = config.hf_config max_bs = min(self.config.max_num_seqs, 512) max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + # 预分配固定地址的输入/输出缓冲区 input_ids = torch.zeros(max_bs, dtype=torch.int64) positions = torch.zeros(max_bs, dtype=torch.int64) slot_mapping = torch.zeros(max_bs, dtype=torch.int32) context_lens = torch.zeros(max_bs, dtype=torch.int32) block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) outputs = torch.zeros(max_bs, hf_config.hidden_size) + # 要捕获的 batch size 列表 self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) self.graphs = {} self.graph_pool = None + # 逆序捕获:先捕获大的 batch size,共享同一个 graph pool for bs in reversed(self.graph_bs): graph = torch.cuda.CUDAGraph() set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) - outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup 运行 with torch.cuda.graph(graph, self.graph_pool): - outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # 捕获计算图 if self.graph_pool is None: - self.graph_pool = graph.pool() + self.graph_pool = graph.pool() # 所有图共享同一个内存池 self.graphs[bs] = graph torch.cuda.synchronize() reset_context() diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index d15979d..862fb9d 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -6,6 +6,22 @@ from nanovllm.engine.block_manager import BlockManager class Scheduler: + """调度器:决定每个步骤(step)中哪些序列被处理以及处理多少 token。 + + 调度策略采用 vLLM 风格的 prefill-decode 分离调度: + 1. **Prefill 优先**: 每次调度先尝试处理等待中的序列(计算其 prompt 的 KV cache)。 + 2. **Chunked prefill**: 如果一个序列的 prompt 太长,可以分多次调度处理。 + 3. **Decode**: 当没有等待中的序列时(或 prefill token 额度用完),处理正在解码的序列。 + 4. **抢占(Preemption)**: 当 KV cache 空间不足时,将最近的 running 序列抢占回 waiting 队列。 + + 调度约束: + - 总 token 数不超过 max_num_batched_tokens(prefill 阶段)。 + - 总序列数不超过 max_num_seqs。 + + Attributes: + waiting: 等待处理的序列队列(FIFO)。 + running: 正在解码的序列队列。 + """ def __init__(self, config: Config): self.max_num_seqs = config.max_num_seqs @@ -17,50 +33,79 @@ class Scheduler: self.running: deque[Sequence] = deque() def is_finished(self): + """检查是否所有序列都已完成。""" return not self.waiting and not self.running def add(self, seq: Sequence): + """将一个新序列加入等待队列。""" self.waiting.append(seq) def schedule(self) -> tuple[list[Sequence], bool]: + """执行一次调度,返回 (被调度的序列列表, 是否为 prefill 阶段)。 + + 调度逻辑: + 1. Prefill 阶段:从 waiting 队列中选取序列,检查前缀缓存命中情况, + 为每个序列计算需要处理的 token 数量。支持 chunked prefill(长 prompt 分多次处理)。 + 2. Decode 阶段:从 running 队列中选取序列,每个序列处理 1 个 token。 + 如果 KV cache 空间不足,会抢占(preempt)最近加入 running 的序列。 + """ scheduled_seqs = [] num_batched_tokens = 0 - # prefill + # ========== Prefill 阶段 ========== + # 尝试从 waiting 队列中调度序列,计算它们的 prompt KV cache while self.waiting and len(scheduled_seqs) < self.max_num_seqs: seq = self.waiting[0] remaining = self.max_num_batched_tokens - num_batched_tokens if remaining == 0: break + if not seq.block_table: + # 序列尚未分配块,检查前缀缓存和空闲块 num_cached_blocks = self.block_manager.can_allocate(seq) if num_cached_blocks == -1: + # 空闲块不足,停止调度 break + # 需要实际处理的 token 数 = 总 prompt token 数 - 缓存命中的 token 数 num_tokens = seq.num_tokens - num_cached_blocks * self.block_size else: + # 序列已经有块(chunked prefill 的后续分片),只需处理未缓存的 token num_tokens = seq.num_tokens - seq.num_cached_tokens - if remaining < num_tokens and scheduled_seqs: # only allow chunked prefill for the first seq + + if remaining < num_tokens and scheduled_seqs: + # token 预算不足以处理整个序列,且已有其他序列被调度 + # 注意:第一个序列允许 chunked prefill(remaining < num_tokens 也可以) break + if not seq.block_table: self.block_manager.allocate(seq, num_cached_blocks) + + # 实际调度的 token 数取 min(num_tokens, remaining),实现 chunked prefill seq.num_scheduled_tokens = min(num_tokens, remaining) num_batched_tokens += seq.num_scheduled_tokens + if seq.num_cached_tokens + seq.num_scheduled_tokens == seq.num_tokens: + # 整个 prompt 已全部处理完毕,转移到 running 队列 seq.status = SequenceStatus.RUNNING self.waiting.popleft() self.running.append(seq) + scheduled_seqs.append(seq) if scheduled_seqs: - return scheduled_seqs, True + return scheduled_seqs, True # is_prefill = True - # decode + # ========== Decode 阶段 ========== + # 逐 token 解码,每个序列每次生成 1 个 token while self.running and len(scheduled_seqs) < self.max_num_seqs: seq = self.running.popleft() + # 检查是否有空闲块用于存储新的 KV cache while not self.block_manager.can_append(seq): if self.running: + # 空间不足,抢占最近加入 running 的序列 self.preempt(self.running.pop()) else: + # 连当前序列都要被抢占 self.preempt(seq) break else: @@ -68,21 +113,37 @@ class Scheduler: seq.is_prefill = False self.block_manager.may_append(seq) scheduled_seqs.append(seq) - assert scheduled_seqs + + assert scheduled_seqs, "No sequences to schedule" + # 将调度过的序列放回 running 队列前端(保持顺序) self.running.extendleft(reversed(scheduled_seqs)) - return scheduled_seqs, False + return scheduled_seqs, False # is_prefill = False def preempt(self, seq: Sequence): + """抢占一个序列:释放其 KV cache 并放回等待队列头部。 + + 抢占后序列需要重新做 prefill(重新计算 KV cache)。 + 这是一种牺牲吞吐量来换取 KV cache 空间的策略。 + """ seq.status = SequenceStatus.WAITING seq.is_prefill = True self.block_manager.deallocate(seq) self.waiting.appendleft(seq) def postprocess(self, seqs: list[Sequence], token_ids: list[int], is_prefill: bool): + """在模型推理完成后处理每个序列的结果。 + + 1. 更新块的哈希值(用于前缀缓存)。 + 2. 更新已缓存的 token 计数。 + 3. 对于 prefill:如果整个 prompt 还没处理完,继续等待下一次调度。 + 4. 对于完成的序列(prefill 结束后或 decode 中):追加生成的 token。 + 5. 检查终止条件(EOS 或达到 max_tokens),完成的序列被释放资源。 + """ for seq, token_id in zip(seqs, token_ids): self.block_manager.hash_blocks(seq) seq.num_cached_tokens += seq.num_scheduled_tokens seq.num_scheduled_tokens = 0 + # prefill 阶段如果还没处理完整个 prompt,不采样 token,继续等待 if is_prefill and seq.num_cached_tokens < seq.num_tokens: continue seq.append_token(token_id) diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 4decfce..5f0086c 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -6,12 +6,35 @@ from nanovllm.sampling_params import SamplingParams class SequenceStatus(Enum): + """序列的生命周期状态。 + + WAITING: 等待被调度器选中进入 prefill 阶段。 + RUNNING: 正在解码(decode)阶段,逐 token 生成。 + FINISHED: 生成完成(遇到 EOS 或达到 max_tokens)。 + """ WAITING = auto() RUNNING = auto() FINISHED = auto() class Sequence: + """表示一个推理请求的序列。 + + 每个 Sequence 封装了一个完整的生成请求,从 prompt tokens 到生成的 completion tokens。 + 它是调度器、块管理器和模型运行器之间传递的核心数据结构。 + + 关键概念: + - prompt tokens: 用户输入的原始 token 序列。 + - completion tokens: 模型生成的 token 序列。 + - cached tokens: 已经计算过 KV cache 的 token 数量(用于前缀缓存和 chunked prefill)。 + - scheduled tokens: 当前调度步骤中计划处理的 token 数量。 + - block_table: 该序列在 KV cache 中占用的物理块 ID 列表。 + + 类属性: + block_size: KV cache 块大小,由 Config 在引擎初始化时设置。 + counter: 全局自增计数器,用于为每个序列分配唯一 ID。 + """ + block_size = 256 counter = count() @@ -20,12 +43,12 @@ class Sequence: self.status = SequenceStatus.WAITING self.token_ids = copy(token_ids) self.last_token = token_ids[-1] - self.num_tokens = len(self.token_ids) - self.num_prompt_tokens = len(token_ids) - self.num_cached_tokens = 0 - self.num_scheduled_tokens = 0 - self.is_prefill = True - self.block_table = [] + self.num_tokens = len(self.token_ids) # 当前总 token 数(prompt + 已生成的) + self.num_prompt_tokens = len(token_ids) # prompt 部分的 token 数,生成过程中不变 + self.num_cached_tokens = 0 # 已计算 KV cache 的 token 数,用于前缀缓存命中判断 + self.num_scheduled_tokens = 0 # 当前步骤中被调度处理的 token 数 + self.is_prefill = True # 是否处于 prefill 阶段(首次计算 prompt 的 KV cache) + self.block_table = [] # KV cache 物理块 ID 列表,索引为逻辑块号 self.temperature = sampling_params.temperature self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos @@ -42,42 +65,56 @@ class Sequence: @property def num_completion_tokens(self): + """已生成的 completion token 数量。""" return self.num_tokens - self.num_prompt_tokens @property def prompt_token_ids(self): + """返回 prompt 部分的 token ID 列表。""" return self.token_ids[:self.num_prompt_tokens] @property def completion_token_ids(self): + """返回已生成的 completion token ID 列表。""" return self.token_ids[self.num_prompt_tokens:] @property def num_blocks(self): + """该序列需要的 KV cache 逻辑块数量(向上取整)。""" return (self.num_tokens + self.block_size - 1) // self.block_size @property def last_block_num_tokens(self): + """最后一个块中已使用的 token 数量。""" return self.num_tokens - (self.num_blocks - 1) * self.block_size def block(self, i): + """获取第 i 个逻辑块对应的 token ID 列表。""" assert 0 <= i < self.num_blocks return self.token_ids[i*self.block_size: (i+1)*self.block_size] def append_token(self, token_id: int): + """将一个新生成的 token 追加到序列末尾。""" self.token_ids.append(token_id) self.last_token = token_id self.num_tokens += 1 def __getstate__(self): + """序列化时只保存必要的状态,用于多进程间传递序列数据。 + + prefill 阶段保存完整 token_ids(模型需要读取全部 prompt tokens), + decode 阶段只保存 last_token(模型只需要最新一个 token 的 ID)。 + """ last_state = self.last_token if not self.is_prefill else self.token_ids return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state) def __setstate__(self, state): + """反序列化,恢复序列状态。""" self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state if isinstance(last_state, list): self.token_ids = last_state self.last_token = self.token_ids[-1] else: + # decode 阶段不需要完整的 token_ids,只保存了 last_token self.token_ids = [] self.last_token = last_state diff --git a/nanovllm/layers/activation.py b/nanovllm/layers/activation.py index 06cced3..84ef4ba 100755 --- a/nanovllm/layers/activation.py +++ b/nanovllm/layers/activation.py @@ -4,6 +4,12 @@ import torch.nn.functional as F class SiluAndMul(nn.Module): + """SwiGLU 激活函数:SiLU(gate) * up。 + + 输入是 gate 和 up 拼接的张量,沿最后一维一分为二, + 对前半部分应用 SiLU 激活后与后半部分逐元素相乘。 + 这是 LLaMA/Qwen 系列模型中 MLP 层的标准激活函数。 + """ @torch.compile def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index e416139..fbb386f 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -18,6 +18,12 @@ def store_kvcache_kernel( slot_mapping_ptr, D: tl.constexpr, ): + """Triton kernel:将新计算的 K/V 向量写入 KV cache。 + + 每个 Triton program 处理一个 token 的 K/V 写入。 + slot_mapping 指定了该 token 的 K/V 应该写入 cache 的哪个 slot。 + slot == -1 表示该 token 不需要写入(如 warmup 阶段)。 + """ idx = tl.program_id(0) slot = tl.load(slot_mapping_ptr + idx) if slot == -1: return @@ -31,6 +37,15 @@ def store_kvcache_kernel( def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor): + """将新计算的 K/V 向量存储到 KV cache 中。 + + Args: + key: [N, num_heads, head_dim] 新计算的 K 向量。 + value: [N, num_heads, head_dim] 新计算的 V 向量。 + k_cache: KV cache 中 K 的存储区域。 + v_cache: KV cache 中 V 的存储区域。 + slot_mapping: [N] 每个 token 对应的 cache slot 索引。 + """ N, num_heads, head_dim = key.shape D = num_heads * head_dim assert key.stride(-1) == 1 and value.stride(-1) == 1 @@ -41,6 +56,18 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, class Attention(nn.Module): + """注意力层:封装了 Flash Attention 和 KV cache 的交互逻辑。 + + 支持 two 阶段的注意力计算: + - Prefill: 使用 flash_attn_varlen_func(变长序列批量注意力),一次性处理整个 prompt。 + - Decode: 使用 flash_attn_with_kvcache(带 KV cache 的注意力),逐 token 生成。 + + 当存在前缀缓存时(block_tables 不为 None),prefill 阶段直接从 KV cache 读取 + 已缓存的 K/V,而不是用当前计算的 K/V。 + + Attributes: + k_cache, v_cache: 绑定到 ModelRunner 分配的全局 KV cache 的对应层视图。 + """ def __init__( self, @@ -54,22 +81,24 @@ class Attention(nn.Module): self.head_dim = head_dim self.scale = scale self.num_kv_heads = num_kv_heads - self.k_cache = self.v_cache = torch.tensor([]) + self.k_cache = self.v_cache = torch.tensor([]) # 占位,由 ModelRunner 分配后绑定 def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): context = get_context() k_cache, v_cache = self.k_cache, self.v_cache if k_cache.numel() and v_cache.numel(): + # 将新计算的 K/V 写入 KV cache store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) if context.is_prefill: - if context.block_tables is not None: # prefix cache + if context.block_tables is not None: # 前缀缓存命中:从 KV cache 读取 K/V k, v = k_cache, v_cache + # Flash Attention 变长版:支持不同长度的序列在同一批次中计算 o = flash_attn_varlen_func(q, k, v, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, softmax_scale=self.scale, causal=True, block_table=context.block_tables) - else: # decode + else: # Decode 阶段:从 KV cache 中读取所有历史 K/V o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, - cache_seqlens=context.context_lens, block_table=context.block_tables, + cache_seqlens=context.context_lens, block_table=context.block_tables, softmax_scale=self.scale, causal=True) return o diff --git a/nanovllm/layers/embed_head.py b/nanovllm/layers/embed_head.py index 84b3ab5..8b638b1 100644 --- a/nanovllm/layers/embed_head.py +++ b/nanovllm/layers/embed_head.py @@ -7,6 +7,12 @@ from nanovllm.utils.context import get_context class VocabParallelEmbedding(nn.Module): + """词表并行 Embedding:将词表按 TP rank 切分。 + + 每个 rank 只存储词表中属于自己的分片。前向计算时: + 1. 只查找属于自己的 token ID,其他位置输出零。 + 2. 通过 all-reduce 聚合所有 rank 的结果。 + """ def __init__( self, @@ -25,6 +31,7 @@ class VocabParallelEmbedding(nn.Module): self.weight.weight_loader = self.weight_loader def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """加载属于当前 rank 的词表分片。""" param_data = param.data shard_size = param_data.size(0) start_idx = self.tp_rank * shard_size @@ -33,16 +40,27 @@ class VocabParallelEmbedding(nn.Module): def forward(self, x: torch.Tensor): if self.tp_size > 1: + # 构造 mask:标记哪些 token ID 属于当前 rank 的范围 mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx) + # 将全局 token ID 转为局部索引 x = mask * (x - self.vocab_start_idx) y = F.embedding(x, self.weight) if self.tp_size > 1: + # 非当前 rank 范围的 token 输出清零,然后 all-reduce 求和 y = mask.unsqueeze(1) * y dist.all_reduce(y) return y class ParallelLMHead(VocabParallelEmbedding): + """并行 LM Head:将隐藏状态映射为词表 logits。 + + 与 Embedding 共享权重(如果模型配置了 tie_word_embeddings), + 但前向逻辑不同:使用 F.linear(矩阵乘法)而非 F.embedding(查表)。 + + Prefill 阶段只需要每个序列最后一个 token 的 logits(因为只有最后一个 token 会用于采样), + 所以先用 cu_seqlens 提取最后位置,再做矩阵乘法,减少计算量。 + """ def __init__( self, @@ -56,10 +74,13 @@ class ParallelLMHead(VocabParallelEmbedding): def forward(self, x: torch.Tensor): context = get_context() if context.is_prefill: + # Prefill: 只取每个序列最后一个 token 的隐藏状态用于采样 last_indices = context.cu_seqlens_q[1:] - 1 x = x[last_indices].contiguous() + # 矩阵乘法得到 logits logits = F.linear(x, self.weight) if self.tp_size > 1: + # Gather 所有 rank 的 logits 分片到 rank 0 all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None dist.gather(logits, all_logits, 0) logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py index 71bf419..03976e9 100755 --- a/nanovllm/layers/layernorm.py +++ b/nanovllm/layers/layernorm.py @@ -3,6 +3,16 @@ from torch import nn class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization(RMSNorm)。 + + 与 LayerNorm 相比,RMSNorm 不计算均值,只计算均方根,计算量更小。 + 公式: output = x / sqrt(mean(x^2) + eps) * weight + + 提供两个前向路径: + - rms_forward: 标准 RMSNorm。 + - add_rms_forward: 将残差加法融合到 RMSNorm 中(x + residual → RMSNorm), + 减少一次显存读写,是 vLLM 等推理框架的常见优化。 + """ def __init__( self, @@ -31,6 +41,11 @@ class RMSNorm(nn.Module): x: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + """融合残差加法和 RMSNorm。 + + Returns: + (normalized_output, updated_residual): 归一化后的输出和更新后的残差(= x + residual)。 + """ orig_dtype = x.dtype x = x.float().add_(residual.float()) residual = x.to(orig_dtype) diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py index d9e8158..0092d70 100755 --- a/nanovllm/layers/linear.py +++ b/nanovllm/layers/linear.py @@ -5,11 +5,20 @@ import torch.distributed as dist def divide(numerator, denominator): + """整除断言,确保张量并行时维度能被均匀切分。""" assert numerator % denominator == 0 return numerator // denominator class LinearBase(nn.Module): + """所有并行线性层的基类。 + + Attributes: + tp_dim: 张量并行切分的维度(0=列切分,1=行切分,None=不切分)。 + tp_rank: 当前进程在 TP 组中的 rank。 + tp_size: TP 组的总大小。 + weight: 权重参数,带有 weight_loader 方法用于加载预训练权重。 + """ def __init__( self, @@ -35,6 +44,7 @@ class LinearBase(nn.Module): class ReplicatedLinear(LinearBase): + """复制式线性层:所有 TP rank 持有完整的权重副本。用于不需要切分的层。""" def __init__( self, @@ -52,6 +62,11 @@ class ReplicatedLinear(LinearBase): class ColumnParallelLinear(LinearBase): + """列并行线性层:将输出维度按 TP rank 切分。 + + 每个 TP rank 持有输出维度的一个分片。例如输出维度为 4096,TP=2 时每个 rank 持有 2048。 + 常用于 QKV 投影和 FFN 的 gate/up 投影(这些层的输出可以独立计算)。 + """ def __init__( self, @@ -63,6 +78,7 @@ class ColumnParallelLinear(LinearBase): super().__init__(input_size, divide(output_size, tp_size), bias, 0) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """加载权重时按 tp_rank 切取对应的列分片。""" param_data = param.data shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size @@ -74,6 +90,13 @@ class ColumnParallelLinear(LinearBase): class MergedColumnParallelLinear(ColumnParallelLinear): + """融合的列并行线性层:将多个线性层合并为一个矩阵乘法。 + + 典型用途是将 gate_proj 和 up_proj 融合为 gate_up_proj, + 减少 kernel launch 次数,提升计算效率。 + + 权重加载时需要根据 shard_id(子层索引)定位到正确的权重分片位置。 + """ def __init__( self, @@ -85,7 +108,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): super().__init__(input_size, sum(output_sizes), bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): + """根据 shard_id 将权重加载到融合矩阵的正确位置。""" param_data = param.data + # 计算该子层在融合矩阵中的偏移量 shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) @@ -94,6 +119,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear): + """QKV 融合的列并行线性层。 + + 将 Q、K、V 三个投影合并为一个矩阵乘法。 + 权重按 [Q | K | V] 的顺序排列,加载时根据 shard_id("q"/"k"/"v") + 定位到对应的位置。 + + 支持 GQA:Q 的 head 数和 KV 的 head 数可以不同。 + """ def __init__( self, @@ -112,6 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear): super().__init__(hidden_size, output_size, bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): + """根据 shard_id ("q"/"k"/"v") 将权重加载到融合矩阵的正确位置。""" param_data = param.data assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": @@ -120,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear): elif loaded_shard_id == "k": shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size - else: + else: # "v" shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) @@ -129,6 +163,14 @@ class QKVParallelLinear(ColumnParallelLinear): class RowParallelLinear(LinearBase): + """行并行线性层:将输入维度按 TP rank 切分。 + + 每个 TP rank 持有输入维度的一个分片。前向计算后需要 all-reduce + 将所有 rank 的结果求和,得到完整的输出。 + 常用于 O 投影和 FFN 的 down 投影(这些层的输出需要跨 rank 聚合)。 + + 偏置项只在 rank 0 添加,避免 all-reduce 后重复加 bias。 + """ def __init__( self, @@ -140,8 +182,10 @@ class RowParallelLinear(LinearBase): super().__init__(divide(input_size, tp_size), output_size, bias, 1) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """加载权重时按 tp_rank 切取对应的行分片。""" param_data = param.data if param_data.ndim == 1: + # bias 不切分,每个 rank 持有完整副本 param_data.copy_(loaded_weight) return shard_size = param_data.size(self.tp_dim) @@ -150,7 +194,8 @@ class RowParallelLinear(LinearBase): param_data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> torch.Tensor: + # 只有 rank 0 加 bias,避免 all-reduce 后重复 y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None) if self.tp_size > 1: - dist.all_reduce(y) + dist.all_reduce(y) # 跨 rank 求和,得到完整输出 return y diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py index f4747a4..df34af3 100644 --- a/nanovllm/layers/rotary_embedding.py +++ b/nanovllm/layers/rotary_embedding.py @@ -8,6 +8,14 @@ def apply_rotary_emb( cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: + """应用旋转位置编码(RoPE)。 + + 将向量沿最后一维分成两半 (x1, x2),然后做旋转变换: + y1 = x1 * cos - x2 * sin + y2 = x2 * cos + x1 * sin + 这等价于在二维平面上将每个相邻的 (x1, x2) 对旋转 theta 角度, + 其中 theta = position / (base^(2i/d))。 + """ x1, x2 = torch.chunk(x.float(), 2, dim=-1) y1 = x1 * cos - x2 * sin y2 = x2 * cos + x1 * sin @@ -15,6 +23,15 @@ def apply_rotary_emb( class RotaryEmbedding(nn.Module): + """旋转位置编码(Rotary Position Embedding, RoPE)。 + + RoPE 通过旋转矩阵编码位置信息,使得注意力计算中 + 内积只依赖相对位置(q_i · k_j 只与 i-j 有关), + 从而天然支持外推到更长序列。 + + 预计算所有位置的 cos 和 sin 值并缓存,避免重复计算。 + 缓存形状: [max_position_embeddings, 1, head_dim],中间维是 num_heads 的广播维度。 + """ def __init__( self, @@ -26,11 +43,14 @@ class RotaryEmbedding(nn.Module): super().__init__() self.head_size = head_size assert rotary_dim == head_size + # 计算逆频率: 1 / (base^(2i/d)), i = 0, 1, ..., d/2-1 inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + # 计算所有位置的频率: pos * inv_freq t = torch.arange(max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() + # 拼接 cos 和 sin,形状 [max_pos, 1, head_dim] cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1) self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -41,6 +61,7 @@ class RotaryEmbedding(nn.Module): query: torch.Tensor, key: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + """根据位置索引查找 cos/sin 并应用到 Q 和 K。""" cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) query = apply_rotary_emb(query, cos, sin) @@ -55,5 +76,9 @@ def get_rope( max_position: int, base: float, ): + """获取 RotaryEmbedding 的单例实例。 + + 使用 lru_cache 确保相同参数只创建一次。 + """ rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) return rotary_emb diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py index 41838ac..d7ca77d 100644 --- a/nanovllm/layers/sampler.py +++ b/nanovllm/layers/sampler.py @@ -3,10 +3,21 @@ from torch import nn class Sampler(nn.Module): + """采样器:将 logits 转换为 token ID。 + + 使用 Gumbel-like 采样方法(而非标准的 top-k/top-p): + 1. 将 logits 除以温度(temperature)。 + 2. 计算 softmax 得到概率分布。 + 3. 用指数分布噪声扰动概率,取 argmax。 + + 这种方法等价于从 softmax(logits/temperature) 分布中采样, + 但避免了逐元素随机选择的低效操作,全部用张量运算实现。 + """ @torch.compile def forward(self, logits: torch.Tensor, temperatures: torch.Tensor): logits = logits.float().div_(temperatures.unsqueeze(dim=1)) probs = torch.softmax(logits, dim=-1) + # 指数分布噪声采样:probs / Exp(1) 的 argmax 等价于按 probs 概率采样 sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1) return sample_tokens diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py index cb147f6..3d76e61 100755 --- a/nanovllm/models/qwen3.py +++ b/nanovllm/models/qwen3.py @@ -12,6 +12,16 @@ from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead class Qwen3Attention(nn.Module): + """Qwen3 的注意力层。 + + 支持 GQA(Grouped Query Attention):num_kv_heads 可以小于 num_heads, + 多个 query head 共享同一组 KV head,减少 KV cache 的显存占用。 + + 当模型配置中没有 qkv_bias 时(Qwen3 默认无 bias),会在 Q 和 K 投影后 + 添加 RMSNorm(q_norm, k_norm),这是 Qwen3 的特殊设计。 + + 支持张量并行(TP):QKV 投影按列切分,O 投影按行切分。 + """ def __init__( self, @@ -29,16 +39,17 @@ class Qwen3Attention(nn.Module): tp_size = dist.get_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + self.num_heads = self.total_num_heads // tp_size # 当前 TP rank 拥有的 query head 数 self.total_num_kv_heads = num_kv_heads assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size + self.num_kv_heads = self.total_num_kv_heads // tp_size # 当前 TP rank 拥有的 KV head 数 self.head_dim = head_dim or hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim ** -0.5 # 注意力缩放因子: 1/sqrt(d_k) self.qkv_bias = qkv_bias + # QKV 融合投影:将 hidden_states 投影为 Q、K、V,按 TP 切分 self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -46,6 +57,7 @@ class Qwen3Attention(nn.Module): self.total_num_kv_heads, bias=qkv_bias, ) + # 输出投影:将注意力输出映射回 hidden_size,按 TP 行切分 self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, @@ -53,6 +65,7 @@ class Qwen3Attention(nn.Module): ) if isinstance(rope_scaling, dict): rope_theta = rope_scaling.get("rope_theta", rope_theta) + # 旋转位置编码(RoPE) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -65,6 +78,7 @@ class Qwen3Attention(nn.Module): self.scaling, self.num_kv_heads, ) + # Qwen3 无 bias 时使用 Q/K norm(post-normalization) if not self.qkv_bias: self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -75,20 +89,30 @@ class Qwen3Attention(nn.Module): hidden_states: torch.Tensor, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) + # 将 QKV 融合结果拆分为 Q、K、V q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = q.view(-1, self.num_heads, self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) if not self.qkv_bias: + # Qwen3 特有:对 Q 和 K 做 RMSNorm q = self.q_norm(q) k = self.k_norm(k) + # 应用旋转位置编码 q, k = self.rotary_emb(positions, q, k) + # 注意力计算(含 KV cache 读写) o = self.attn(q, k, v) + # 输出投影 output = self.o_proj(o.flatten(1, -1)) return output class Qwen3MLP(nn.Module): + """Qwen3 的前馈网络(MLP),使用 SwiGLU 激活函数。 + + 结构: hidden_size → 2×intermediate_size (gate + up) → intermediate_size → hidden_size + 其中 gate 和 up 投影融合为一个矩阵乘法,然后 SiLU(gate) * up。 + """ def __init__( self, @@ -97,9 +121,10 @@ class Qwen3MLP(nn.Module): hidden_act: str, ) -> None: super().__init__() + # 融合 gate_proj 和 up_proj 为一个矩阵,减少一次 kernel launch self.gate_up_proj = MergedColumnParallelLinear( hidden_size, - [intermediate_size] * 2, + [intermediate_size] * 2, # gate 和 up 各输出 intermediate_size bias=False, ) self.down_proj = RowParallelLinear( @@ -108,7 +133,7 @@ class Qwen3MLP(nn.Module): bias=False, ) assert hidden_act == "silu" - self.act_fn = SiluAndMul() + self.act_fn = SiluAndMul() # SiLU(gate) * up def forward(self, x): gate_up = self.gate_up_proj(x) @@ -118,6 +143,11 @@ class Qwen3MLP(nn.Module): class Qwen3DecoderLayer(nn.Module): + """Qwen3 的单个 Transformer 解码层。 + + 结构: Input RMSNorm → Self-Attention → Residual → Post-Attention RMSNorm → MLP → Residual + 使用 Pre-Norm 架构(先归一化再进入子层),并将残差连接的计算融合到 RMSNorm 中以节省显存。 + """ def __init__( self, @@ -149,6 +179,7 @@ class Qwen3DecoderLayer(nn.Module): hidden_states: torch.Tensor, residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: + # 残差连接融合到 RMSNorm 中:residual = hidden_states + residual, output = RMSNorm(residual) if residual is None: hidden_states, residual = self.input_layernorm(hidden_states), hidden_states else: @@ -160,6 +191,7 @@ class Qwen3DecoderLayer(nn.Module): class Qwen3Model(nn.Module): + """Qwen3 的 Transformer 主体:Embedding → N × DecoderLayer → Final RMSNorm。""" def __init__( self, @@ -179,11 +211,19 @@ class Qwen3Model(nn.Module): residual = None for layer in self.layers: hidden_states, residual = layer(positions, hidden_states, residual) + # 最终的 RMSNorm 也融合了残差加法 hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Qwen3ForCausalLM(nn.Module): + """Qwen3 因果语言模型:用于文本生成。 + + 包含映射表,将 HuggingFace 的独立权重名(q_proj, k_proj, v_proj, gate_proj, up_proj) + 映射到本项目融合后的权重名(qkv_proj, gate_up_proj),以便正确加载权重。 + """ + + # 融合模块的映射关系:HF 权重名 → (本项目模块名, shard_id) packed_modules_mapping = { "q_proj": ("qkv_proj", "q"), "k_proj": ("qkv_proj", "k"), @@ -199,6 +239,7 @@ class Qwen3ForCausalLM(nn.Module): super().__init__() self.model = Qwen3Model(config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + # 如果模型配置了权重共享(tie),LM Head 和 Embedding 使用同一个权重矩阵 if config.tie_word_embeddings: self.lm_head.weight.data = self.model.embed_tokens.weight.data @@ -213,4 +254,5 @@ class Qwen3ForCausalLM(nn.Module): self, hidden_states: torch.Tensor, ) -> torch.Tensor: + """将最后一层隐藏状态通过 LM Head 转换为词表 logits。""" return self.lm_head(hidden_states) diff --git a/nanovllm/sampling_params.py b/nanovllm/sampling_params.py index 3e46598..141d50c 100644 --- a/nanovllm/sampling_params.py +++ b/nanovllm/sampling_params.py @@ -3,6 +3,16 @@ from dataclasses import dataclass @dataclass(slots=True) class SamplingParams: + """生成采样的参数配置。 + + Args: + temperature: 采样温度,控制输出的随机性。值越大越随机,越接近 0 越确定。 + 注意:本项目不支持 temperature=0(贪心解码),必须大于 1e-10。 + max_tokens: 单个请求最大生成的 token 数量。 + ignore_eos: 是否忽略 EOS token。设为 True 时即使遇到结束符也继续生成,直到 max_tokens 耗尽。 + 基准测试中用于确保每个请求都生成固定数量的 token。 + """ + temperature: float = 1.0 max_tokens: int = 64 ignore_eos: bool = False diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 3b02a1d..19a2de5 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -4,6 +4,21 @@ import torch @dataclass(slots=True) class Context: + """全局上下文:存储当前推理步骤的注意力相关元数据。 + + 这个对象在每次推理步骤开始时被 ModelRunner 设置,在模型的前向传播中 + 被 Attention 层读取。它是一个全局单例,避免了通过函数参数层层传递。 + + Attributes: + is_prefill: 当前是否为 prefill 阶段。 + cu_seqlens_q: 查询的累积序列长度(prefill 阶段使用),标记每个序列的边界。 + cu_seqlens_k: 键值的累积序列长度(prefill 阶段使用),可能与 cu_seqlens_q 不同(前缀缓存)。 + max_seqlen_q: 批次中最长的查询序列长度(prefill 使用)。 + max_seqlen_k: 批次中最长的键值序列长度(prefill 使用)。 + slot_mapping: 每个 token 在 KV cache 中的存储位置索引(用于写入新的 K/V)。 + context_lens: 每个序列的上下文总长度(decode 阶段使用)。 + block_tables: KV cache 块映射表,将逻辑块映射到物理块(decode 和前缀缓存使用)。 + """ is_prefill: bool = False cu_seqlens_q: torch.Tensor | None = None cu_seqlens_k: torch.Tensor | None = None @@ -16,12 +31,15 @@ class Context: _CONTEXT = Context() def get_context(): + """获取当前全局上下文。""" return _CONTEXT def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None): + """设置当前推理步骤的全局上下文。""" global _CONTEXT _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) def reset_context(): + """重置全局上下文(推理步骤结束后调用)。""" global _CONTEXT _CONTEXT = Context() diff --git a/nanovllm/utils/loader.py b/nanovllm/utils/loader.py index 4ef8040..d263a77 100644 --- a/nanovllm/utils/loader.py +++ b/nanovllm/utils/loader.py @@ -6,23 +6,38 @@ from safetensors import safe_open def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor): + """默认权重加载器:直接将加载的权重拷贝到参数中。""" param.data.copy_(loaded_weight) def load_model(model: nn.Module, path: str): + """从 HuggingFace safetensors 格式加载模型权重。 + + 支持融合模块的权重加载:本项目将 Q/K/V 投影融合为 qkv_proj, + 将 gate/up 投影融合为 gate_up_proj。加载时需要通过 packed_modules_mapping + 将原始的独立权重名映射到融合后的模块,并使用自定义的 weight_loader + 将权重放置到正确位置。 + + Args: + model: 要加载权重的模型。 + path: 模型目录路径,包含 .safetensors 文件。 + """ packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) for file in glob(os.path.join(path, "*.safetensors")): with safe_open(file, "pt", "cpu") as f: for weight_name in f.keys(): + # 检查是否为融合模块的子权重(如 q_proj, k_proj, gate_proj 等) for k in packed_modules_mapping: if k in weight_name: v, shard_id = packed_modules_mapping[k] + # 替换权重名:如 "model.layers.0.self_attn.q_proj.weight" → "...qkv_proj.weight" param_name = weight_name.replace(k, v) param = model.get_parameter(param_name) weight_loader = getattr(param, "weight_loader") weight_loader(param, f.get_tensor(weight_name), shard_id) break else: + # 普通权重:直接加载 param = model.get_parameter(weight_name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, f.get_tensor(weight_name))