minor simplify
This commit is contained in:
@@ -5,9 +5,6 @@ import torch.nn.functional as F
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.compile
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, y = x.chunk(2, -1)
|
||||
|
||||
@@ -141,8 +141,7 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
if param_data.dim() == 1:
|
||||
# bias is not sharded in RowParallelLinear
|
||||
if param_data.ndim == 1:
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
|
||||
@@ -4,9 +4,6 @@ from torch import nn
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.compile
|
||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
||||
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
||||
|
||||
Reference in New Issue
Block a user