Merge pull request #203 from Anai-Guo/fix-row-parallel-bias-crash
fix RowParallelLinear weight_loader crash when bias is enabled
This commit is contained in:
@@ -141,6 +141,10 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
if param_data.dim() == 1:
|
||||
# bias is not sharded in RowParallelLinear
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||
|
||||
Reference in New Issue
Block a user