import os from glob import glob import torch from torch import nn from safetensors import safe_open def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor): """默认权重加载器:直接将加载的权重拷贝到参数中。""" param.data.copy_(loaded_weight) def load_model(model: nn.Module, path: str): """从 HuggingFace safetensors 格式加载模型权重。 支持融合模块的权重加载:本项目将 Q/K/V 投影融合为 qkv_proj, 将 gate/up 投影融合为 gate_up_proj。加载时需要通过 packed_modules_mapping 将原始的独立权重名映射到融合后的模块,并使用自定义的 weight_loader 将权重放置到正确位置。 Args: model: 要加载权重的模型。 path: 模型目录路径,包含 .safetensors 文件。 """ packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) for file in glob(os.path.join(path, "*.safetensors")): with safe_open(file, "pt", "cpu") as f: for weight_name in f.keys(): # 检查是否为融合模块的子权重(如 q_proj, k_proj, gate_proj 等) for k in packed_modules_mapping: if k in weight_name: v, shard_id = packed_modules_mapping[k] # 替换权重名:如 "model.layers.0.self_attn.q_proj.weight" → "...qkv_proj.weight" param_name = weight_name.replace(k, v) param = model.get_parameter(param_name) weight_loader = getattr(param, "weight_loader") weight_loader(param, f.get_tensor(weight_name), shard_id) break else: # 普通权重:直接加载 param = model.get_parameter(weight_name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, f.get_tensor(weight_name))