From 8e15d4312ad66e71ef6cbc3bafb163548d414c4f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:17:29 +0800
Subject: [PATCH 1/3] Add some random clamping in model.py

---
 .../ASR/pruned_transducer_stateless7/model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index ee88a9159..7a5d037fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp
 
 from icefall.utils import add_sos
 
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,10 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +181,9 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),

From f4442de1c4fd8623afe2a6571989deef45cfcaa6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:34:26 +0800
Subject: [PATCH 2/3] Add reflect=0.1 to invocations of random_clamp()

---
 .../pruned_transducer_stateless7/conformer.py |  3 ++-
 .../ASR/pruned_transducer_stateless7/model.py |  9 ++++++---
 .../pruned_transducer_stateless7/scaling.py   | 20 ++++++++++++++----
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..3edfd2595 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1116,7 +1116,8 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
-                                               prob=0.5)
+                                               prob=0.5,
+                                               reflect=0.1)
 
         # attn_output_weights: (batch, head, time1, time2)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 7a5d037fc..0c1bd9551 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -143,8 +143,10 @@ class Transducer(nn.Module):
         am = self.simple_am_proj(encoder_out)
 
         if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)
 
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -182,7 +184,8 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
         if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..19e8e6fa8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans
 
     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None
 
 
 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)

From 45c38dec61f50f78e3c058e1870d6b0f1080e1d6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:35:17 +0800
Subject: [PATCH 3/3] Remove in_balancer.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 3edfd2595..c929d2124 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -869,8 +869,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
 
-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -930,7 +928,7 @@ class RelPositionMultiheadAttention(nn.Module):
             and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
            self.linear_pos(pos_emb),
            self.attention_dim,
            self.num_heads,
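Note (not part of the patches above): a minimal usage sketch of the patched random_clamp(), assuming the script is run from the pruned_transducer_stateless7 directory so that `from scaling import random_clamp` resolves. With the reflect term, elements that do get clamped still receive a small reversed gradient (-reflect times the upstream gradient) instead of exactly zero, which appears to be the point of passing reflect=0.1 at the call sites.

    # Illustrative sketch only; relies on the behaviour defined in the
    # patched scaling.py above.
    import torch

    from scaling import random_clamp

    torch.manual_seed(0)
    x = (10.0 * torch.randn(4, 8)).requires_grad_(True)

    # Each element is clamped to [-8, 2] with probability 0.5. Because the
    # output is ans * (1 + reflect) - x * reflect, clamped elements get a
    # gradient of -reflect (here -0.1) times the upstream gradient, while
    # unclamped elements keep a gradient factor of exactly 1.0.
    y = random_clamp(x, min=-8.0, max=2.0, prob=0.5, reflect=0.1)
    y.sum().backward()

    print(x.grad.unique())  # expected values: -0.1 and 1.0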