Upload to Main

This commit is contained in:
张菲
2025-10-07 22:42:55 +08:00
commit d3ddab7c5d
218 changed files with 125815 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .parallel import *
from .recorder import *

View File

@@ -0,0 +1,16 @@
from torch.nn import DataParallel
class MyDataParallel(DataParallel):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)
def __setattr__(self, name, value):
try:
if name == "no_grad":
return setattr(self.module, name, value)
return super().__setattr__(name, value)
except AttributeError:
return setattr(self.module, name, value)

View File

@@ -0,0 +1,18 @@
from torch import nn
class EmbeddingRecorder(nn.Module):
def __init__(self, record_embedding: bool = False):
super().__init__()
self.record_embedding = record_embedding
def forward(self, x):
if self.record_embedding:
self.embedding = x
return x
def __enter__(self):
self.record_embedding = True
def __exit__(self, exc_type, exc_val, exc_tb):
self.record_embedding = False