minor simplify

This commit is contained in:
GeekExplorer
2026-04-13 22:09:46 +08:00
parent 02a95fdc66
commit 9e8507ef41
6 changed files with 14 additions and 27 deletions
+1 -2
View File
@@ -141,8 +141,7 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if param_data.dim() == 1:
# bias is not sharded in RowParallelLinear
if param_data.ndim == 1:
param_data.copy_(loaded_weight)
return
shard_size = param_data.size(self.tp_dim)