import torch from torch import nn import torch.nn.functional as F class SiluAndMul(nn.Module): """SwiGLU 激活函数:SiLU(gate) * up。 输入是 gate 和 up 拼接的张量,沿最后一维一分为二, 对前半部分应用 SiLU 激活后与后半部分逐元素相乘。 这是 LLaMA/Qwen 系列模型中 MLP 层的标准激活函数。 """ @torch.compile def forward(self, x: torch.Tensor) -> torch.Tensor: x, y = x.chunk(2, -1) return F.silu(x) * y