51 lines
1.2 KiB
Python
51 lines
1.2 KiB
Python
import functools
|
|
import torch.nn as nn
|
|
|
|
from .build import HEAD_REGISTRY
|
|
|
|
|
|
class MLP(nn.Module):
    """Multi-layer perceptron built from repeated linear groups.

    For every entry of ``hidden_layers`` the network appends, in order:
    ``nn.Linear`` -> ``nn.BatchNorm1d`` (only if ``bn``) -> activation ->
    ``nn.Dropout`` (only if ``dropout > 0``).

    Args:
        in_features: Input size of the first linear layer.
        hidden_layers: Output size of each linear layer, in order. A single
            int is treated as a one-element list. Must not be empty.
        activation: Name of the activation, ``"relu"`` or ``"leaky_relu"``.
            Both are created with ``inplace=True``.
        bn: Whether to insert ``nn.BatchNorm1d`` after each linear layer.
        dropout: Dropout probability; a dropout layer is added after each
            activation only when this is greater than zero.

    Attributes:
        out_features: Output size of the last linear layer.
        mlp: The assembled ``nn.Sequential``.

    Raises:
        ValueError: If ``hidden_layers`` is empty.
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(
        self,
        in_features=2048,
        hidden_layers=(),
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()

        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        else:
            # Materialise once so tuples and one-shot iterables behave alike.
            hidden_layers = list(hidden_layers)

        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, which would turn this into an opaque IndexError.
        if not hidden_layers:
            raise ValueError(
                "hidden_layers must contain at least one layer size"
            )
        self.out_features = hidden_layers[-1]

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported activation: {activation!r} "
                "(expected 'relu' or 'leaky_relu')"
            )

        layers = []
        fan_in = in_features
        for hidden_dim in hidden_layers:
            layers.append(nn.Linear(fan_in, hidden_dim))
            if bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(act_fn())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            # The next linear layer consumes this layer's output.
            fan_in = hidden_dim

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return the result."""
        return self.mlp(x)
|
|
|
|
|
|
@HEAD_REGISTRY.register()
def mlp(**kwargs):
    """Construct an ``MLP`` from keyword arguments.

    All keyword arguments are forwarded unchanged to ``MLP``. The function
    is decorated with ``HEAD_REGISTRY.register()``; presumably this makes it
    discoverable by name through the registry (confirm in ``.build``).
    """
    head = MLP(**kwargs)
    return head
|