Upload to Main
This commit is contained in:
18
deepcore/nets/nets_utils/recorder.py
Normal file
18
deepcore/nets/nets_utils/recorder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class EmbeddingRecorder(nn.Module):
|
||||
def __init__(self, record_embedding: bool = False):
|
||||
super().__init__()
|
||||
self.record_embedding = record_embedding
|
||||
|
||||
def forward(self, x):
|
||||
if self.record_embedding:
|
||||
self.embedding = x
|
||||
return x
|
||||
|
||||
def __enter__(self):
|
||||
self.record_embedding = True
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.record_embedding = False
|
||||
Reference in New Issue
Block a user