fix RowParallelLinear weight_loader crash when bias is enabled
When RowParallelLinear has bias=True, the weight_loader crashes with an IndexError because it calls param_data.size(tp_dim) where tp_dim=1, but the bias tensor is 1D and only has dimension 0. The bias in RowParallelLinear is not sharded (all ranks hold the full bias, only rank 0 applies it), so skip the sharding logic for 1D params. Fixes GeeeekExplorer/nano-vllm#125
This commit is contained in:
@@ -141,6 +141,10 @@ class RowParallelLinear(LinearBase):
|
|||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
|
if param_data.dim() == 1:
|
||||||
|
# bias is not sharded in RowParallelLinear
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
return
|
||||||
shard_size = param_data.size(self.tp_dim)
|
shard_size = param_data.size(self.tp_dim)
|
||||||
start_idx = self.tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user