From 8e15d4312ad66e71ef6cbc3bafb163548d414c4f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:17:29 +0800
Subject: [PATCH 01/12] Add some random clamping in model.py

---
 .../ASR/pruned_transducer_stateless7/model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index ee88a9159..7a5d037fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,10 +19,12 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import random_clamp

 from icefall.utils import add_sos

+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -140,6 +142,10 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

+        if self.training:
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -175,6 +181,9 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

+        if self.training:
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),

From f4442de1c4fd8623afe2a6571989deef45cfcaa6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:34:26 +0800
Subject: [PATCH 02/12] Add reflect=0.1 to invocations of random_clamp()

---
 .../pruned_transducer_stateless7/conformer.py |  3 ++-
 .../ASR/pruned_transducer_stateless7/model.py |  9 ++++++---
 .../pruned_transducer_stateless7/scaling.py   | 20 ++++++++++++++-----
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..3edfd2595 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1116,7 +1116,8 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = random_clamp(attn_output_weights,
                                            min=-attn_weights_max,
                                            max=attn_weights_max,
-                                           prob=0.5)
+                                           prob=0.5,
+                                           reflect=0.1)

         # attn_output_weights: (batch, head, time1, time2)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 7a5d037fc..0c1bd9551 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -143,8 +143,10 @@ class Transducer(nn.Module):
         am = self.simple_am_proj(encoder_out)

         if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5)
+            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
+                              reflect=0.1)
+            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
+                              reflect=0.1)

         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -182,7 +184,8 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5)
+            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
+                                  reflect=0.1)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..19e8e6fa8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -165,24 +165,34 @@ class RandomClampFunction(torch.autograd.Function):
                 x: Tensor,
                 min: Optional[float],
                 max: Optional[float],
-                prob: float) -> Tensor:
+                prob: float,
+                reflect: float) -> Tensor:
         x_clamped = torch.clamp(x, min=min, max=max)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
             ctx.save_for_backward(ans == x)
+        ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+
         return ans

     @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
         is_same, = ctx.saved_tensors
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None


 def random_clamp(x: Tensor,
                  min: Optional[float] = None,
                  max: Optional[float] = None,
-                 prob: float = 0.5):
-    return RandomClampFunction.apply(x, min, max, prob)
+                 prob: float = 0.5,
+                 reflect: float = 0.0):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)


From 45c38dec61f50f78e3c058e1870d6b0f1080e1d6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 12:35:17 +0800
Subject: [PATCH 03/12] Remove in_balancer.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 3edfd2595..c929d2124 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -869,8 +869,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.copy_pos_query = Identity()
         self.copy_query = Identity()

-        self.in_balancer = ActivationBalancer(3 * attention_dim,
-                                              channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -930,7 +928,7 @@ class RelPositionMultiheadAttention(nn.Module):
         and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_proj(x),
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,

From d37c159174f12171c64df19d68d6be0b6624294d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 13:41:58 +0800
Subject: [PATCH 04/12] Revert model.py so there are no constraints on the output.
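
The clamping being reverted here was applied to lm, am and logits only during
training.  For reference, the behaviour that this patch removes from model.py
can be summarized by the following forward-only sketch of random_clamp() with
the reflect option (the helper name random_clamp_fwd is illustrative only; the
real autograd.Function in scaling.py additionally masks the gradient of the
clamped elements):

    import torch

    def random_clamp_fwd(x, min=None, max=None, prob=0.5, reflect=0.1):
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob        # clamp each element with probability prob
        ans = torch.where(mask, x_clamped, x)
        if reflect != 0.0:
            # No-op where ans == x; where clamping was applied, a fraction of the
            # excess is "reflected" back past the clamping boundary.
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans
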
---
 .../ASR/pruned_transducer_stateless7/model.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 0c1bd9551..ee88a9159 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,12 +19,10 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from scaling import random_clamp

 from icefall.utils import add_sos

-
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -142,12 +140,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
-                              reflect=0.1)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
-                              reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -183,10 +175,6 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
-                                  reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),

From 9c54906e63de31e014bf4051bb6ed1612250d897 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 19:16:03 +0800
Subject: [PATCH 05/12] Implement randomized backprop for softmax.

---
 .../pruned_transducer_stateless7/conformer.py |  7 +-
 .../pruned_transducer_stateless7/scaling.py   | 64 +++++++++++++++++++
 2 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..90acc99f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    softmax,
 )

 from torch import Tensor, nn
@@ -1148,7 +1149,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )

-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
@@ -1561,7 +1562,7 @@ class AttentionCombine(nn.Module):
                 single_prob_mask)

             weights = weights.masked_fill(mask, float('-inf'))
-        weights = weights.softmax(dim=1)
+        weights = softmax(weights, dim=1)

         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..d8a380b9d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -185,6 +185,57 @@ def random_clamp(x: Tensor,
     return RandomClampFunction.apply(x, min, max, prob)


+def random_cast_to_half(x: Tensor,
+                        min_abs: float = 1.0e-03) -> Tensor:
+    """
+    A randomized way of casting a floating point value to half precision.
+    """
+    if x.dtype == torch.float16:
+        return x
+    x_sign = x.sign()
+    x_abs = x.abs()
+    is_too_small = (x_abs < min_abs)
+    # for elements where is_too_small is true, random_val will contain +-min_abs with
+    # probability (x.abs() / min_abs), and 0.0 otherwise.  [so this preserves expectations,
+    # for those elements].
+    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+    return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    """
+    Tries to handle half-precision derivatives in a randomized way that should
+    be more accurate for training than the default behavior.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, dim: int):
+        ans = x.softmax(dim=dim)
+        ctx.save_for_backward(ans)
+        ctx.dim = dim
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        ans, = ctx.saved_tensors
+
+        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
+            # use a randomized approach to convert to float16
+            with torch.cuda.amp.autocast(enabled=False):
+                ans_grad = ans_grad.to(torch.float32)
+                ans = ans.to(torch.float32)
+                x_grad = ans_grad * ans
+                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+                return random_cast_to_half(x_grad), None
+        else:
+            x_grad = ans_grad * ans
+            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            return x_grad, None
+
+
+def softmax(x: Tensor,
+            dim: int):
+    return SoftmaxFunction.apply(x, dim)
+

 class MaxEigLimiterFunction(torch.autograd.Function):
     @staticmethod
@@ -942,11 +993,24 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


+def _test_softmax():
+    a = torch.randn(2, 10, dtype=torch.float64)
+    b = a.clone()
+    a.requires_grad = True
+    b.requires_grad = True
+    a.softmax(dim=1)[:,0].sum().backward()
+    print("a grad = ", a.grad)
+    softmax(b, dim=1)[:,0].sum().backward()
+    print("b grad = ", b.grad)
+    assert torch.allclose(a.grad, b.grad)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_softmax()
     _test_whiten()
     _test_max_eig()
     _test_activation_balancer_sign()

From 0ad4462632a63727bb3830bfdfbb3e82a2105ed3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 19:27:28 +0800
Subject: [PATCH 06/12] Reduce min_abs from 1e-03 to 1e-04

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index fed553183..773bab4e9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -196,7 +196,7 @@ def random_clamp(x: Tensor,


 def random_cast_to_half(x: Tensor,
-                        min_abs: float = 1.0e-03) -> Tensor:
+                        min_abs: float = 1.0e-04) -> Tensor:
     """
     A randomized way of casting a floating point value to half precision.
     """

From a4443efa95b1274df6df0f3dddc24a2e56918c4d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 19:46:17 +0800
Subject: [PATCH 07/12] Add RandomGrad with min_abs=1.0e-04

---
 .../pruned_transducer_stateless7/conformer.py |  5 +--
 .../pruned_transducer_stateless7/scaling.py   | 36 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 565990708..52aa66bc3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     softmax,
+    RandomGrad,
 )

 from torch import Tensor, nn
@@ -304,7 +305,7 @@ class ConformerEncoderLayer(nn.Module):
                              whitening_limit=5.0,
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
-
+        self.random_grad = RandomGrad()

     def forward(
         self,
@@ -364,7 +365,7 @@ class ConformerEncoderLayer(nn.Module):
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
             src = src_orig + delta * self.bypass_scale

-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))


 class ConformerEncoder(nn.Module):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 773bab4e9..23afb4387 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -211,6 +211,42 @@ def random_cast_to_half(x: Tensor,
     random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
     return torch.where(is_too_small, random_val, x).to(torch.float16)

+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=ctx.min_abs), None
+        else:
+            return ans_grad, None
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+    def __init__(self,
+                 min_abs: float = 1.0e-04):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+

 class SoftmaxFunction(torch.autograd.Function):
     """

From cc15552510243539154ad5a1f074c853d9b2e7dc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 19:53:53 +0800
Subject: [PATCH 08/12] Use full precision to do softmax and store ans.
---
 .../ASR/pruned_transducer_stateless7/scaling.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 773bab4e9..595287a2a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -219,10 +219,13 @@ class SoftmaxFunction(torch.autograd.Function):
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
-        ans = x.softmax(dim=dim)
-        ctx.save_for_backward(ans)
-        ctx.dim = dim
-        return ans
+        with torch.cuda.amp.autocast(enabled=False):
+            if x.dtype == torch.float16:
+                x = x.to(torch.float32)
+            ans = x.softmax(dim=dim)
+            ctx.save_for_backward(ans)
+            ctx.dim = dim
+            return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor):

From f6b8f0f63142429e131cce5acbbb32568d3c60d1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 12:49:29 +0800
Subject: [PATCH 09/12] Fix bug in backprop of random_clamp()

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 19e8e6fa8..31c389461 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
         ctx.reflect = reflect
         if reflect != 0.0:
             ans = ans * (1.0 + reflect) - (x * reflect)
-
         return ans

     @staticmethod
@@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
         reflect = ctx.reflect
         if reflect != 0.0:
             x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
+        return x_grad, None, None, None, None

 def random_clamp(x: Tensor,
                  min: Optional[float] = None,

From d137118484034887da485fc4abaf5a488345b488 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 13:23:48 +0800
Subject: [PATCH 10/12] Get the randomized backprop for softmax in autocast mode working.

---
 .../pruned_transducer_stateless7/scaling.py | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 280153137..f819bdf7c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -219,30 +219,32 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
         ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
         ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
         ctx.dim = dim
         return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-
-        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
-            # use a randomized approach to convert to float16
-            with torch.cuda.amp.autocast(enabled=False):
-                ans_grad = ans_grad.to(torch.float32)
-                ans = ans.to(torch.float32)
-                x_grad = ans_grad * ans
-                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-                return random_cast_to_half(x_grad), None
-        else:
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
             x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            if ctx.x_dtype == torch.float16:
+                x_grad = random_cast_to_half(x_grad)
+            return x_grad, None
+

 def softmax(x: Tensor,
             dim: int):
+    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
     return SoftmaxFunction.apply(x, dim)

@@ -867,7 +869,6 @@ class DoubleSwish(torch.nn.Module):


 def _test_max_eig():
-
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
         x = torch.randn(100, 128)
@@ -891,7 +892,7 @@ def _test_max_eig():
         y.backward(gradient=y_grad)

         if proportion < 0.2:
-            assert torch.allclose(x.grad, y_grad)
+            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
         elif proportion > 1.0:
             assert not torch.allclose(x.grad, y_grad)

From 679ba2ee5e0e04bc6bf7ed34aff02cd0b87104b9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 13:30:55 +0800
Subject: [PATCH 11/12] Remove debug print

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f819bdf7c..a4e78f25c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -244,7 +244,6 @@ class SoftmaxFunction(torch.autograd.Function):

 def softmax(x: Tensor,
             dim: int):
-    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
     return SoftmaxFunction.apply(x, dim)


From 6601035db1139560c74e4306a0de8bf1b1f110d6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 13:53:10 +0800
Subject: [PATCH 12/12] Reduce min_abs from 1.0e-04 to 5.0e-06

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index fd1091b7f..80154de51 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -195,7 +195,7 @@ def random_clamp(x: Tensor,


 def random_cast_to_half(x: Tensor,
-                        min_abs: float = 1.0e-04) -> Tensor:
+                        min_abs: float = 5.0e-06) -> Tensor:
     """
     A randomized way of casting a floating point value to half precision.
     """
@@ -236,7 +236,7 @@ class RandomGrad(torch.nn.Module):
     accuracy of training when using amp (automatic mixed precision)
     """
     def __init__(self,
-                 min_abs: float = 1.0e-04):
+                 min_abs: float = 5.0e-06):
         super(RandomGrad, self).__init__()
         self.min_abs = min_abs
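
A note on the min_abs values tuned in patches 06 and 12: random_cast_to_half()
keeps expectations unchanged for inputs smaller than min_abs, since each such
element becomes +-min_abs with probability |x| / min_abs and 0 otherwise.  A
small check of that property (this assumes it is run from inside the
pruned_transducer_stateless7 directory so that scaling.py is importable):

    import torch
    from scaling import random_cast_to_half

    torch.manual_seed(0)
    x = torch.full((100000,), 1.0e-06)            # well below min_abs = 5.0e-06
    y = random_cast_to_half(x, min_abs=5.0e-06)   # each element is either 0.0 or ~5.0e-06
    # The float16 output is unbiased in expectation: its mean stays close to 1.0e-06.
    print(x.mean().item(), y.float().mean().item())
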