diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp index 49b562073..6108dabd4 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py index 673f4ef9a..bef12f8fd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py @@ -26,32 +26,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -class AdapterHook(): - ''' - Implementation of the forward hook to track feature statistics and compute a loss on them. - Will compute mean and variance, and will use l2 as a loss - ''' - def __init__(self, module): - self.hook = module.register_forward_hook(self.hook_fn) - - def hook_fn(self, module, input, output): - # hook co compute deepinversion's feature distribution regularization - nch = input[0].shape[1] - mean = input[0].mean([0, 2, 3]) - var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) - - #forcing mean and variance to match between two distributions - #other ways might work better, i.g. KL divergence - r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm( - module.running_mean.data - mean, 2) - - self.r_feature = r_feature - # must have no output - - def close(self): - self.hook.remove() - - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf "Sequence Transduction with Recurrent Neural Networks"