release code
This commit is contained in:
50
Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
Normal file
50
Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import torch.nn as nn
|
||||
|
||||
from .build import HEAD_REGISTRY
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features=2048,
|
||||
hidden_layers=[],
|
||||
activation="relu",
|
||||
bn=True,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(hidden_layers, int):
|
||||
hidden_layers = [hidden_layers]
|
||||
|
||||
assert len(hidden_layers) > 0
|
||||
self.out_features = hidden_layers[-1]
|
||||
|
||||
mlp = []
|
||||
|
||||
if activation == "relu":
|
||||
act_fn = functools.partial(nn.ReLU, inplace=True)
|
||||
elif activation == "leaky_relu":
|
||||
act_fn = functools.partial(nn.LeakyReLU, inplace=True)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for hidden_dim in hidden_layers:
|
||||
mlp += [nn.Linear(in_features, hidden_dim)]
|
||||
if bn:
|
||||
mlp += [nn.BatchNorm1d(hidden_dim)]
|
||||
mlp += [act_fn()]
|
||||
if dropout > 0:
|
||||
mlp += [nn.Dropout(dropout)]
|
||||
in_features = hidden_dim
|
||||
|
||||
self.mlp = nn.Sequential(*mlp)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
@HEAD_REGISTRY.register()
|
||||
def mlp(**kwargs):
|
||||
return MLP(**kwargs)
|
||||
Reference in New Issue
Block a user