From 851912c5810ea0332a0f1681def0db0ed317f058 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 30 Dec 2022 15:13:00 +0800
Subject: [PATCH 1/4] Remove bias from BasicNorm, add an eps instead.

---
 .../pruned_transducer_stateless7/scaling.py   | 64 +++++++++++--------
 1 file changed, 38 insertions(+), 26 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 0e10658fc..02102f82c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -434,46 +434,43 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 class BasicNormFunction(torch.autograd.Function):
     # This computes:
-    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+    #   scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
     #   return (x - bias) * scales
     # (after unsqueezing the bias), but it does it in a memory-efficient way so that
     # it can just store the returned value (chances are, this will also be needed for
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
-    @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
-        for _ in range(channel_dim + 1, x.ndim):
-            bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
+        scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
         ans = x * scales
         ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), bias.detach(), log_scale.detach())
+                              scales.detach(), eps.detach(), scale.detach())
         return ans

     @staticmethod
-    @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+        ans_or_x, scales, eps, scale = ctx.saved_tensors
         if ctx.store_output_for_backprop:
             x = ans_or_x / scales
         else:
             x = ans_or_x
-        x = x.detach()
-        x.requires_grad = True
-        bias.requires_grad = True
-        log_scale.requires_grad = True
-        with torch.enable_grad():
-            # recompute scales from x, bias and log_scale.
-            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
-            ans = x * scales
-            ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+        with torch.cuda.amp.autocast(enabled=False):
+            assert eps.dtype != torch.float16 and scale.dtype != torch.float16
+            x = x.to(torch.float32).detach()
+            x.requires_grad = True
+            eps.requires_grad = True
+            scale.requires_grad = True
+            with torch.enable_grad():
+                # recompute scales from x, eps and scale.
+                scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
+                ans = x * scales
+                ans.backward(gradient=ans_grad.to(torch.float32))
+        return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
@@ -500,8 +497,11 @@ class BasicNorm(torch.nn.Module):
          {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this is
          learnable.
+      eps: the initial epsilon value (not in log space)
       log_scale_min: FloatLike, minimum allowed value of log_scale
       log_scale_max: FloatLike, maximum allowed value of log_scale
+      log_eps_min: FloatLike, minimum allowed value of log_eps
+      log_eps_max: FloatLike, maximum allowed value of log_eps
       store_output_for_backprop: only possibly affects memory use; recommend
          to set to True if you think the output of this module is more likely
          than the input of this module to be required to be stored for the
@@ -512,18 +512,26 @@ class BasicNorm(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int = -1,  # CAUTION: see documentation.
-        log_scale: float = 1.0,
-        log_scale_min: float = -1.5,
-        log_scale_max: float = 1.5,
+        log_scale: float = 2.0,
+        eps: float = 0.25,
+        log_scale_min: FloatLike = -1.5,
+        log_scale_max: FloatLike = 1.5,
+        log_eps_min: FloatLike = -3.0,
+        log_eps_max: FloatLike = 3.0,
         store_output_for_backprop: bool = False
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.log_scale = nn.Parameter(torch.tensor(log_scale))
-        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
+
         self.log_scale_min = log_scale_min
         self.log_scale_max = log_scale_max
+
+        self.log_eps_min = log_eps_min
+        self.log_eps_max = log_eps_max
+
         self.store_output_for_backprop = store_output_for_backprop

     def forward(self, x: Tensor) -> Tensor:
@@ -537,7 +545,7 @@ class BasicNorm(torch.nn.Module):
             bias = self.bias
             for _ in range(channel_dim + 1, x.ndim):
                 bias = bias.unsqueeze(-1)
-            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
@@ -545,8 +553,12 @@
         log_scale = limit_param_value(self.log_scale,
                                       min=float(self.log_scale_min),
                                       max=float(self.log_scale_max),
                                       training=self.training)
+        log_eps = limit_param_value(self.log_eps,
+                                    min=float(self.log_eps_min),
+                                    max=float(self.log_eps_max),
+                                    training=self.training)

-        return BasicNormFunction.apply(x, self.bias, log_scale,
+        return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
                                        self.channel_dim,
                                        self.store_output_for_backprop)

From c4101c78734614b640a17cdb76ad0486c8313349 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 30 Dec 2022 15:31:38 +0800
Subject: [PATCH 2/4] Change initial log_scale from 2 to 0. (was 1.0 in previous expt

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 02102f82c..74db889c6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -512,7 +512,7 @@ class BasicNorm(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int = -1,  # CAUTION: see documentation.
-        log_scale: float = 2.0,
+        log_scale: float = 0.0,
         eps: float = 0.25,
         log_scale_min: FloatLike = -1.5,
         log_scale_max: FloatLike = 1.5,

From d604284f16c9614922f44e6971ff02a813b99677 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 30 Dec 2022 15:48:12 +0800
Subject: [PATCH 3/4] Change initial log_scale back to 1.0 and initial eps to 0.1

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 74db889c6..ba05fe212 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -512,7 +512,7 @@ class BasicNorm(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int = -1,  # CAUTION: see documentation.
-        log_scale: float = 0.0,
+        log_scale: float = 1.0,
         eps: float = 0.25,
         log_scale_min: FloatLike = -1.5,
         log_scale_max: FloatLike = 1.5,

From 8952b69d4210ea1e5d52f1fa100835b0f07db5b1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 30 Dec 2022 16:28:44 +0800
Subject: [PATCH 4/4] Reduce BasicNorm.eps default from 0.25 to 0.1

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ba05fe212..8a81e6c3f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -513,7 +513,7 @@ class BasicNorm(torch.nn.Module):
         num_channels: int,
         channel_dim: int = -1,  # CAUTION: see documentation.
         log_scale: float = 1.0,
-        eps: float = 0.25,
+        eps: float = 0.1,
         log_scale_min: FloatLike = -1.5,
         log_scale_max: FloatLike = 1.5,
         log_eps_min: FloatLike = -3.0,
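
For reference, the net effect of this series is that BasicNorm computes
scales = (mean(x**2) + eps) ** -0.5 * scale, with log_scale initialized to 1.0
and eps to 0.1. Below is a minimal sketch of that formula only: it omits the
memory-efficient BasicNormFunction autograd path, the limit_param_value()
clamping of log_scale/log_eps, and the scripting branch, so it is an
illustration rather than a drop-in replacement for the module in scaling.py;
the class name BasicNormSketch is made up for this example.

import torch
import torch.nn as nn
from torch import Tensor


class BasicNormSketch(nn.Module):
    """Reference-only version of the eps-based BasicNorm in this patch series.

    Shows only the math: scales = (mean(x**2) + eps) ** -0.5 * scale.
    """

    def __init__(self, num_channels: int, channel_dim: int = -1,
                 log_scale: float = 1.0, eps: float = 0.1) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # scale is learned in log space, as in the patch.
        self.log_scale = nn.Parameter(torch.tensor(float(log_scale)))
        # eps is also stored in log space; the constructor takes it un-logged.
        self.log_eps = nn.Parameter(torch.tensor(float(eps)).log().detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        # Normalize by the root-mean-square over channels, with eps added
        # inside the mean (no bias subtraction), then apply the learned scale.
        scales = ((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                   + self.log_eps.exp()) ** -0.5) * self.log_scale.exp()
        return x * scales


if __name__ == "__main__":
    m = BasicNormSketch(num_channels=4)
    y = m(torch.randn(2, 3, 4))
    print(y.shape)  # torch.Size([2, 3, 4])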