From 8e15d4312ad66e71ef6cbc3bafb163548d414c4f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:17:29 +0800
Subject: [PATCH] Add some random clamping in model.py

---
 .../ASR/pruned_transducer_stateless7/model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index ee88a9159..7a5d037fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,10 +19,12 @@
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp
 
 from icefall.utils import add_sos
 
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,10 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +181,9 @@ class Transducer(nn.Module):
             # prior to do_rnnt_pruning (this is an optimization for speed).
             logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
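
Note (editorial, not part of the commit): random_clamp is imported from the
recipe's scaling.py, whose definition is not included in this patch.  The
sketch below is only a guess at the forward-pass behaviour implied by the
call sites: on a random fraction `prob` of training calls, clamp the tensor
to [min, max] as a guard against activations drifting to extreme values.
The real helper in scaling.py may well differ, e.g. by drawing the decision
per element or by adjusting gradients rather than forward values (a plain
clamp, as written here, zeroes gradients for out-of-range elements).

    import torch


    def random_clamp(
        x: torch.Tensor,
        min: float,
        max: float,
        prob: float = 0.5,
    ) -> torch.Tensor:
        # Hypothetical stand-in for scaling.random_clamp: with probability
        # `prob` (decided once per call), clamp every element of `x` to
        # [min, max]; otherwise return `x` untouched.  The parameter names
        # `min`/`max` deliberately shadow the Python builtins so that the
        # keyword-argument calls in the patch above type-check.
        if torch.rand(1).item() < prob:
            return x.clamp(min=min, max=max)
        return x

Under this reading, the call sites in the patch behave as written, e.g.
`lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)` leaves `lm` unchanged
on roughly half of the training steps and clamps it to [-8.0, 2.0] on the
rest; at inference time the `if self.training:` guard skips clamping
entirely.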