diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 90acc99f7..565990708 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -870,8 +870,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.copy_pos_query = Identity()
             self.copy_query = Identity()
 
-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -931,7 +929,7 @@ class RelPositionMultiheadAttention(nn.Module):
             and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,
@@ -1117,7 +1115,8 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
-                                               prob=0.5)
+                                               prob=0.5,
+                                               reflect=0.1)
 
         # attn_output_weights: (batch, head, time1, time2)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index ee88a9159..0c1bd9551 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp
 
 from icefall.utils import add_sos
 
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,12 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +183,10 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index d8a380b9d..fed553183 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans
 
     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None
 
 
 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)
 
 
 def random_cast_to_half(x: Tensor,
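
For reference, a minimal usage sketch of the new reflect option in random_clamp (not part of the patch; it only assumes the script is run from the pruned_transducer_stateless7 directory so that scaling.py is importable). With prob=1.0 every element is clamped, and clamped elements receive a gradient of -reflect instead of 0, so the backward pass still "sees" out-of-range values.

    # sanity check of random_clamp(..., reflect=...) as defined in scaling.py above
    import torch
    from scaling import random_clamp  # assumes cwd is pruned_transducer_stateless7/

    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)

    # prob=1.0 => every element is clamped to [-1, 1]; reflect=0.1 mixes in a
    # small negative copy of x:  y = 1.1 * clamp(x, -1, 1) - 0.1 * x
    y = random_clamp(x, min=-1.0, max=1.0, prob=1.0, reflect=0.1)
    y.sum().backward()

    print(y)       # [-0.9, -0.5, 0.5, 0.9]
    print(x.grad)  # in-range elements get grad 1.0, clamped elements get -0.1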