diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index d4d33303a..7ea1a29ef 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -18,10 +18,11 @@
 import k2
 import torch
 import torch.nn as nn
+import random
 from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
-from scaling import ActivationBalancer
+from scaling import penalize_abs_values_gt
 
 
 class Transducer(nn.Module):
@@ -68,18 +69,6 @@ class Transducer(nn.Module):
         )
         self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
 
-
-        self.lm_balancer = ActivationBalancer(vocab_size,
-                                              channel_dim=-1,
-                                              min_positive=0, max_positive=1,
-                                              min_abs=0, max_abs=50.0,
-                                              max_factor=0.1)
-        self.am_balancer = ActivationBalancer(vocab_size,
-                                              channel_dim=-1,
-                                              min_positive=0, max_positive=1,
-                                              min_abs=0, max_abs=30.0,
-                                              max_factor=0.1)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -150,8 +139,13 @@ class Transducer(nn.Module):
             boundary[:, 2] = y_lens
             boundary[:, 3] = x_lens
 
-        lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
-        am = self.am_balancer(self.simple_am_proj(encoder_out))
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        if self.training and random.random() < 0.25:
+            lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
+        if self.training and random.random() < 0.25:
+            am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
 
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
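
For reference, `penalize_abs_values_gt` is defined in this recipe's `scaling.py`; the diff only changes how it is applied. Below is a minimal, self-contained sketch of the idea, written for illustration rather than copied from `scaling.py` (the real implementation is more memory-efficient, and the class name `_PenalizeAbsValuesGt` here is mine): the function is the identity in the forward pass, and in backprop it adds the gradient of `penalty * (x.abs() - limit).relu().sum()`, so only elements whose magnitude exceeds `limit` feel a small push back toward the allowed range.

```python
import torch
from torch import Tensor


class _PenalizeAbsValuesGt(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, adds the
    gradient of `penalty * (x.abs() - limit).relu().sum()` to the
    incoming gradient, nudging out-of-range values back toward the limit.
    Illustrative sketch only, not the icefall scaling.py implementation."""

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        # d/dx of penalty * relu(|x| - limit) is penalty * sign(x)
        # wherever |x| > limit, and zero elsewhere.
        extra_grad = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        # One gradient per forward input: x, limit, penalty.
        return grad_output + extra_grad, None, None


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    return _PenalizeAbsValuesGt.apply(x, limit, penalty)


if __name__ == "__main__":
    x = torch.tensor([0.5, -200.0, 150.0], requires_grad=True)
    y = penalize_abs_values_gt(x, 100.0, 1.0e-04)
    y.sum().backward()
    # Grad is 1.0 for the in-range element, and 1.0 -/+ 1e-4 for the
    # elements whose magnitude exceeds the limit of 100.0.
    print(x.grad)
```

Compared with the removed `ActivationBalancer` modules, this is cheaper and simpler: the penalty runs only in training and only on a random 25% of batches, and it leaves the forward activations untouched, so decoding is unaffected. The point is to keep the `lm` and `am` logits fed to `k2.rnnt_loss_smoothed` from drifting to implausibly large magnitudes.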