diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index ee88a9159..9700e4487 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
+from scaling import ActivationBalancer
 
 
 class Transducer(nn.Module):
@@ -67,6 +68,16 @@ class Transducer(nn.Module):
         )
         self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+
+        self.lm_balancer = ActivationBalancer(vocab_size,
+                                              channel_dim=-1,
+                                              min_positive=0, max_positive=1,
+                                              min_abs=0, max_abs=50.0)
+        self.am_balancer = ActivationBalancer(vocab_size,
+                                              channel_dim=-1,
+                                              min_positive=0, max_positive=1,
+                                              min_abs=0, max_abs=50.0)
+
 
     def forward(
         self,
         x: torch.Tensor,
@@ -137,8 +148,8 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        lm = self.simple_lm_proj(decoder_out)
-        am = self.simple_am_proj(encoder_out)
+        lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
+        am = self.am_balancer(self.simple_am_proj(encoder_out))
 
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
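
Note: as I understand it, scaling.ActivationBalancer is an identity in the forward pass and only adjusts gradients during training to discourage channel activations outside the configured range; with min_positive=0 and max_positive=1 the sign constraint is effectively disabled, so the balancers above only act on the max_abs=50.0 magnitude limit for the simple_lm_proj / simple_am_proj outputs before they reach the fp32 rnnt_loss_smoothed call. The sketch below is a minimal, simplified stand-in for that idea, not the icefall implementation; the names SimpleActivationBalancer, _BalanceGrad, and the grad_scale parameter are hypothetical, introduced only for illustration.

import torch


class _BalanceGrad(torch.autograd.Function):
    """Identity in forward; backward adds a small gradient term for |x| > max_abs."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, max_abs: float, grad_scale: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.max_abs = max_abs
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        # For elements whose magnitude exceeds max_abs, add a gradient component
        # with the sign of x (scaled relative to the incoming gradient), so the
        # optimizer nudges those activations back toward the allowed range.
        too_large = (x.abs() > ctx.max_abs).to(grad_output.dtype)
        extra = ctx.grad_scale * too_large * x.sign() * grad_output.abs().mean()
        return grad_output + extra, None, None


class SimpleActivationBalancer(torch.nn.Module):
    """Toy stand-in for the balancer in the diff: no-op at inference time,
    gradient-only effect during training."""

    def __init__(self, max_abs: float = 50.0, grad_scale: float = 0.04):
        super().__init__()
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or not x.requires_grad:
            return x
        return _BalanceGrad.apply(x, self.max_abs, self.grad_scale)


if __name__ == "__main__":
    proj = torch.nn.Linear(512, 500)      # stand-in for simple_lm_proj
    balancer = SimpleActivationBalancer(max_abs=50.0)
    decoder_out = torch.randn(8, 10, 512)
    lm = balancer(proj(decoder_out))      # same values as proj(decoder_out)
    lm.sum().backward()
    print(proj.weight.grad.shape)         # torch.Size([500, 512])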