diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py index 2a9e8d5..5e54baa 100755 --- a/nanovllm/layers/linear.py +++ b/nanovllm/layers/linear.py @@ -141,6 +141,10 @@ class RowParallelLinear(LinearBase): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: + # bias is not sharded in RowParallelLinear + param_data.copy_(loaded_weight) + return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)