Files
Rain-Bus ffd2defdfc add Chinese annotations to all source files for learning purposes
Annotated 16 source files covering the full architecture:
engine (scheduler, block manager, model runner), layers (attention,
linear, sampler, etc.), model (qwen3), and utils.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 21:33:15 +08:00

44 lines
2.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from glob import glob
import torch
from torch import nn
from safetensors import safe_open
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Copy a checkpoint tensor into ``param`` in place.

    This is the fallback loader used for parameters that do not define their
    own ``weight_loader`` attribute.

    Args:
        param: Destination parameter; its ``.data`` is overwritten.
        loaded_weight: Tensor read from the checkpoint.

    Raises:
        ValueError: If the two tensors hold a different number of elements.
    """
    # ``copy_`` broadcasts its source, so without this check a smaller
    # checkpoint tensor would be silently replicated across the parameter
    # instead of being reported as a mismatch.
    if param.numel() != loaded_weight.numel():
        raise ValueError(
            f"weight shape mismatch: parameter is {tuple(param.shape)}, "
            f"checkpoint tensor is {tuple(loaded_weight.shape)}"
        )
    param.data.copy_(loaded_weight)
def load_model(model: nn.Module, path: str):
    """Load weights from HuggingFace-style safetensors files into ``model``.

    Fused modules are supported: the project fuses the Q/K/V projections into
    ``qkv_proj`` and the gate/up projections into ``gate_up_proj``. The
    model's optional ``packed_modules_mapping`` attribute maps each original
    (unfused) weight-name fragment to a ``(fused_name, shard_id)`` pair. A
    checkpoint weight whose name contains such a fragment is routed to the
    fused parameter, whose own ``weight_loader`` places the shard at the
    right position. All other weights are loaded by the parameter's
    ``weight_loader`` if it has one, otherwise by ``default_weight_loader``.

    Args:
        model: The model whose parameters receive the weights.
        path: Directory containing the ``.safetensors`` files.

    Raises:
        FileNotFoundError: If ``path`` contains no ``.safetensors`` files.
    """
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    # Sorted so shards are visited in the same order on every platform
    # (``glob`` returns entries in arbitrary, OS-dependent order).
    weight_files = sorted(glob(os.path.join(path, "*.safetensors")))
    if not weight_files:
        # Returning silently here would leave the model with whatever values
        # its parameters were initialized with.
        raise FileNotFoundError(f"no .safetensors files found in {path!r}")
    for file in weight_files:
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                # Check whether this is a sub-weight of a fused module
                # (e.g. q_proj, k_proj, gate_proj).
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        # Rewrite the name, e.g.
                        # "model.layers.0.self_attn.q_proj.weight"
                        # -> "...qkv_proj.weight".
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        # No fallback: a fused parameter must provide its own
                        # shard-aware loader (AttributeError otherwise).
                        weight_loader = param.weight_loader
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    # Runs only when no fused-module fragment matched:
                    # an ordinary weight, loaded directly.
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))