17 lines
520 B
Python
17 lines
520 B
Python
from torch.nn import DataParallel
|
|
|
|
|
|
class MyDataParallel(DataParallel):
|
|
def __getattr__(self, name):
|
|
try:
|
|
return super().__getattr__(name)
|
|
except AttributeError:
|
|
return getattr(self.module, name)
|
|
def __setattr__(self, name, value):
|
|
try:
|
|
if name == "no_grad":
|
|
return setattr(self.module, name, value)
|
|
return super().__setattr__(name, value)
|
|
except AttributeError:
|
|
return setattr(self.module, name, value)
|