18 lines
452 B
Python
18 lines
452 B
Python
from torch import nn
|
|
|
|
|
|
class EmbeddingRecorder(nn.Module):
    """Identity layer that can optionally keep a reference to its input.

    ``forward`` always returns its input unchanged. While
    ``record_embedding`` is true, the input is additionally stored on
    ``self.embedding`` so it can be read after the forward pass.

    The instance is also a context manager: entering it turns recording on,
    and exiting restores whatever recording state was in effect before the
    matching entry (so nested ``with`` blocks and recorders constructed with
    ``record_embedding=True`` keep their state).

    Example::

        recorder = EmbeddingRecorder()
        with recorder as rec:
            rec(x)
        captured = rec.embedding  # same object as x
    """

    def __init__(self, record_embedding: bool = False):
        super().__init__()
        self.record_embedding = record_embedding
        # Last recorded input; None until a forward pass runs with recording on,
        # so reading it early does not raise AttributeError.
        self.embedding = None
        # Recording states saved by __enter__, restored (LIFO) by __exit__ so
        # nested use on the same instance behaves correctly.
        self._saved_record_states = []

    def forward(self, x):
        """Return ``x`` unchanged, storing it on ``self.embedding`` if recording."""
        if self.record_embedding:
            # Stores a reference, not a copy: the recorded object is the input itself.
            self.embedding = x
        return x

    def __enter__(self):
        """Enable recording and return this recorder for use with ``as``."""
        self._saved_record_states.append(self.record_embedding)
        self.record_embedding = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the recording state active before the matching ``__enter__``.

        Falls back to disabling recording if ``__exit__`` is called without a
        prior ``__enter__``. Exceptions raised in the ``with`` body propagate.
        """
        if self._saved_record_states:
            self.record_embedding = self._saved_record_states.pop()
        else:
            self.record_embedding = False
        return False