Upload to Main
This commit is contained in:
16
deepcore/nets/nets_utils/parallel.py
Normal file
16
deepcore/nets/nets_utils/parallel.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class MyDataParallel(DataParallel):
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.module, name)
|
||||
def __setattr__(self, name, value):
|
||||
try:
|
||||
if name == "no_grad":
|
||||
return setattr(self.module, name, value)
|
||||
return super().__setattr__(name, value)
|
||||
except AttributeError:
|
||||
return setattr(self.module, name, value)
|
||||
Reference in New Issue
Block a user