diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 0c1bd9551..ee88a9159 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,12 +19,10 @@
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from scaling import random_clamp
 
 from icefall.utils import add_sos
 
-
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -142,12 +140,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
-                              reflect=0.1)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
-                              reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -183,10 +175,6 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
-                                  reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 773bab4e9..280153137 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
         ctx.reflect = reflect
         if reflect != 0.0:
             ans = ans * (1.0 + reflect) - (x * reflect)
-
         return ans
 
     @staticmethod
@@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
         reflect = ctx.reflect
         if reflect != 0.0:
             x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
+        return x_grad, None, None, None, None
 
 
 def random_clamp(x: Tensor, min: Optional[float] = None,
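
For readers following the scaling.py change: the backward previously computed the reflect-adjusted gradient into x_grad but then returned the unadjusted ans_grad * is_same, so with reflect != 0 the gradient of the forward's "ans * (1.0 + reflect) - (x * reflect)" step was dropped. The sketch below is a minimal, self-contained approximation of such a random-clamp autograd function after the fix; it is not the icefall source, and the class name RandomClampSketch as well as the forward details (the prob mask and the is_same bookkeeping) are assumptions reconstructed from the calls visible in the diff.

# Hedged sketch (assumption, not the icefall scaling.py source): a random-clamp
# autograd function whose backward matches the fixed return value above.
import torch
from torch import Tensor


class RandomClampSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float, prob: float, reflect: float) -> Tensor:
        # Clamp each element to [min, max] with probability `prob`.
        clamped = x.clamp(min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, clamped, x)
        # is_same marks elements left unchanged; their gradient passes through.
        is_same = ans == x
        ctx.save_for_backward(is_same)
        ctx.reflect = reflect
        if reflect != 0.0:
            # Mix in a small negative copy of the input ("reflection").
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            # Chain rule for ans * (1 + reflect) - x * reflect.
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        # The fix: return the reflect-adjusted x_grad rather than the raw masked gradient.
        return x_grad, None, None, None, None


if __name__ == "__main__":
    x = torch.randn(4, 5, requires_grad=True)
    y = RandomClampSketch.apply(x, -1.0, 1.0, 0.5, 0.1)
    y.sum().backward()
    print(x.grad)

With reflect != 0.0, d(ans)/d(x) works out to (1.0 + reflect) * is_same - reflect elementwise, which is exactly what the corrected return value propagates; the old return ignored the reflect term entirely.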