diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 2a2e2ddf5..3291bc351 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -21,6 +21,7 @@ import warnings
 from typing import List, Optional, Tuple
 import logging
 import torch
+import random
 from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
@@ -1090,6 +1091,9 @@ class AttentionCombine(nn.Module):
 
         weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
 
+        if random.random() < 0.002:
+            logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
+
         if self.training:
             # random masking..
             mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),