From f4442de1c4fd8623afe2a6571989deef45cfcaa6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:34:26 +0800
Subject: [PATCH] Add reflect=0.1 to invocations of random_clamp()

---
 .../pruned_transducer_stateless7/conformer.py |  3 ++-
 .../ASR/pruned_transducer_stateless7/model.py |  9 ++++++---
 .../pruned_transducer_stateless7/scaling.py   | 20 +++++++++++++++-----
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..3edfd2595 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1116,7 +1116,8 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
-                                               prob=0.5)
+                                               prob=0.5,
+                                               reflect=0.1)
 
         # attn_output_weights: (batch, head, time1, time2)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 7a5d037fc..0c1bd9551 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -143,8 +143,10 @@ class Transducer(nn.Module):
         am = self.simple_am_proj(encoder_out)
 
         if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)
 
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -182,7 +184,8 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
         if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..19e8e6fa8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans
 
     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None
 
 
 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)
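
For context, a minimal standalone sketch (not part of this patch, and not icefall's actual scaling.py code) of what the new `reflect` argument does in the forward pass: elements that were not clamped are left untouched, while clamped elements are pushed back inside the range by `reflect` times the amount they overshot, rather than sitting exactly on the boundary. The helper name `toy_random_clamp` and the numbers below are illustrative only.

# Illustrative sketch only; mirrors the patched forward() logic, minus autograd.
import torch

def toy_random_clamp(x, min=None, max=None, prob=0.5, reflect=0.0):
    x_clamped = torch.clamp(x, min=min, max=max)
    mask = torch.rand_like(x) < prob      # clamp each element with probability `prob`
    ans = torch.where(mask, x_clamped, x)
    if reflect != 0.0:
        # Unclamped elements (ans == x) are unchanged; clamped elements are
        # reflected back inside the range by `reflect` * (amount of overshoot).
        ans = ans * (1.0 + reflect) - (x * reflect)
    return ans

x = torch.tensor([-10.0, 0.0, 4.0])
y = toy_random_clamp(x, min=-5.0, max=2.0, prob=1.0, reflect=0.1)
print(y)  # tensor([-4.5000,  0.0000,  1.8000]): clamped values land just inside the bounds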