Upload to Main
This commit is contained in:
2
deepcore/nets/nets_utils/__init__.py
Normal file
2
deepcore/nets/nets_utils/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .parallel import *
|
||||
from .recorder import *
|
||||
BIN
deepcore/nets/nets_utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
deepcore/nets/nets_utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
deepcore/nets/nets_utils/__pycache__/parallel.cpython-39.pyc
Normal file
BIN
deepcore/nets/nets_utils/__pycache__/parallel.cpython-39.pyc
Normal file
Binary file not shown.
BIN
deepcore/nets/nets_utils/__pycache__/recorder.cpython-39.pyc
Normal file
BIN
deepcore/nets/nets_utils/__pycache__/recorder.cpython-39.pyc
Normal file
Binary file not shown.
16
deepcore/nets/nets_utils/parallel.py
Normal file
16
deepcore/nets/nets_utils/parallel.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class MyDataParallel(DataParallel):
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.module, name)
|
||||
def __setattr__(self, name, value):
|
||||
try:
|
||||
if name == "no_grad":
|
||||
return setattr(self.module, name, value)
|
||||
return super().__setattr__(name, value)
|
||||
except AttributeError:
|
||||
return setattr(self.module, name, value)
|
||||
18
deepcore/nets/nets_utils/recorder.py
Normal file
18
deepcore/nets/nets_utils/recorder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class EmbeddingRecorder(nn.Module):
|
||||
def __init__(self, record_embedding: bool = False):
|
||||
super().__init__()
|
||||
self.record_embedding = record_embedding
|
||||
|
||||
def forward(self, x):
|
||||
if self.record_embedding:
|
||||
self.embedding = x
|
||||
return x
|
||||
|
||||
def __enter__(self):
|
||||
self.record_embedding = True
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.record_embedding = False
|
||||
Reference in New Issue
Block a user