Add some random clamping in model.py

Daniel Povey 2022-10-19 12:17:29 +08:00
parent c3c655d0bd
commit 8e15d4312a


@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp
 from icefall.utils import add_sos


 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,10 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +181,9 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),