release code
This commit is contained in:
69
Dassl.ProGrad.pytorch/dassl/utils/registry.py
Normal file
69
Dassl.ProGrad.pytorch/dassl/utils/registry.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Modified from https://github.com/facebookresearch/fvcore
|
||||
"""
|
||||
__all__ = ["Registry"]
|
||||
|
||||
|
||||
class Registry:
|
||||
"""A registry providing name -> object mapping, to support
|
||||
custom modules.
|
||||
|
||||
To create a registry (e.g. a backbone registry):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
BACKBONE_REGISTRY = Registry('BACKBONE')
|
||||
|
||||
To register an object:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
class MyBackbone(nn.Module):
|
||||
...
|
||||
|
||||
Or:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
BACKBONE_REGISTRY.register(MyBackbone)
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
self._obj_map = dict()
|
||||
|
||||
def _do_register(self, name, obj, force=False):
|
||||
if name in self._obj_map and not force:
|
||||
raise KeyError(
|
||||
'An object named "{}" was already '
|
||||
'registered in "{}" registry'.format(name, self._name)
|
||||
)
|
||||
|
||||
self._obj_map[name] = obj
|
||||
|
||||
def register(self, obj=None, force=False):
|
||||
if obj is None:
|
||||
# Used as a decorator
|
||||
def wrapper(fn_or_class):
|
||||
name = fn_or_class.__name__
|
||||
self._do_register(name, fn_or_class, force=force)
|
||||
return fn_or_class
|
||||
|
||||
return wrapper
|
||||
|
||||
# Used as a function call
|
||||
name = obj.__name__
|
||||
self._do_register(name, obj, force=force)
|
||||
|
||||
def get(self, name):
|
||||
if name not in self._obj_map:
|
||||
raise KeyError(
|
||||
'Object name "{}" does not exist '
|
||||
'in "{}" registry'.format(name, self._name)
|
||||
)
|
||||
|
||||
return self._obj_map[name]
|
||||
|
||||
def registered_names(self):
|
||||
return list(self._obj_map.keys())
|
||||
Reference in New Issue
Block a user