From d37c159174f12171c64df19d68d6be0b6624294d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 13:41:58 +0800
Subject: [PATCH 1/3] Revert model.py so there are no constraints on the
 output.

---
 .../ASR/pruned_transducer_stateless7/model.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 0c1bd9551..ee88a9159 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,12 +19,10 @@
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from scaling import random_clamp
 
 from icefall.utils import add_sos
 
-
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -142,12 +140,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        if self.training:
-            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
-                              reflect=0.1)
-            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
-                              reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -183,10 +175,6 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        if self.training:
-            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
-                                  reflect=0.1)
-
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),

From f6b8f0f63142429e131cce5acbbb32568d3c60d1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 12:49:29 +0800
Subject: [PATCH 2/3] Fix bug in backprop of random_clamp()

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 19e8e6fa8..31c389461 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
         ctx.reflect = reflect
         if reflect != 0.0:
             ans = ans * (1.0 + reflect) - (x * reflect)
-
         return ans
 
     @staticmethod
@@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
         reflect = ctx.reflect
         if reflect != 0.0:
             x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
+        return x_grad, None, None, None, None
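
For context on patch 2: RandomClampFunction.backward() already computed the reflected gradient in x_grad, but the old return statement sent back ans_grad * is_same, which silently dropped the -reflect gradient on clamped elements whenever reflect != 0.0; the patch returns x_grad instead. Below is a minimal, self-contained sketch of that forward/backward pair, not the icefall code: the class name ReflectClampSketch is made up and the random prob gating of the real random_clamp() is omitted. With the patched return value gradcheck passes; returning ans_grad * is_same would zero the gradient of clamped elements and fail it.

import torch
from torch import Tensor


class ReflectClampSketch(torch.autograd.Function):
    # Forward: ans = clamp(x) * (1 + reflect) - x * reflect, so elements
    # pushed back into range get gradient -reflect instead of 0.
    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float, reflect: float) -> Tensor:
        clamped = x.clamp(min=min, max=max)
        is_same = clamped == x          # elements the clamp left untouched
        ctx.save_for_backward(is_same)
        ctx.reflect = reflect
        ans = clamped
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        is_same, = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            # in-range elements: (1 + r) - r = 1; clamped elements: 0 - r = -r
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        # Returning `ans_grad * is_same` here (the pre-patch code) would drop
        # the reflection term for clamped elements.
        return x_grad, None, None, None


if __name__ == "__main__":
    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(
        lambda t: ReflectClampSketch.apply(t, -0.5, 0.5, 0.1), (x,)
    )
    print("gradcheck passed")
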
From d137118484034887da485fc4abaf5a488345b488 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 13:23:48 +0800
Subject: [PATCH 3/3] Get the randomized backprop for softmax in autocast mode
 working.

---
 .../pruned_transducer_stateless7/scaling.py   | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 280153137..f819bdf7c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -219,30 +219,32 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
         ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
         ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
         ctx.dim = dim
         return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-
-        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
-            # use a randomized approach to convert to float16
-            with torch.cuda.amp.autocast(enabled=False):
-                ans_grad = ans_grad.to(torch.float32)
-                ans = ans.to(torch.float32)
-                x_grad = ans_grad * ans
-                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-                return random_cast_to_half(x_grad), None
-        else:
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
             x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            if ctx.x_dtype == torch.float16:
+                x_grad = random_cast_to_half(x_grad)
             return x_grad, None
 
+
 def softmax(x: Tensor, dim: int):
+    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
     return SoftmaxFunction.apply(x, dim)
@@ -867,7 +869,6 @@ class DoubleSwish(torch.nn.Module):
 
 
 def _test_max_eig():
-
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
         x = torch.randn(100, 128)
@@ -891,7 +892,7 @@ def _test_max_eig():
         y.backward(gradient=y_grad)
 
         if proportion < 0.2:
-            assert torch.allclose(x.grad, y_grad)
+            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
         elif proportion > 1.0:
             assert not torch.allclose(x.grad, y_grad)
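
On patch 3: under autocast, a float16 input to softmax comes back as float32, so the backward now always works in float32 and only rounds the gradient back to float16 (via random_cast_to_half()) when the original input dtype was float16. The sketch below illustrates the idea behind such a randomized cast; it is an assumption-laden reconstruction, not necessarily the exact icefall helper (the name random_cast_to_half_sketch and the min_abs threshold are illustrative): gradient values too small for float16 are kept alive in expectation instead of all underflowing to zero.

import torch
from torch import Tensor


def random_cast_to_half_sketch(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """Cast float32 -> float16; elements with |x| < min_abs become
    +/- min_abs with probability |x| / min_abs (else 0.0), so the
    expected value of the result equals x instead of underflowing.
    Sketch only: the real random_cast_to_half() in scaling.py may differ."""
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    # For tiny elements, emit +/- min_abs with probability |x| / min_abs.
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


if __name__ == "__main__":
    g = torch.full((1_000_000,), 2.0e-08)   # below float16's smallest subnormal
    print(g.half().abs().max().item())       # 0.0: a plain cast flushes everything to zero
    print(random_cast_to_half_sketch(g).float().mean().item())  # approx 2e-08 in expectation
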