diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 8d4fbc46c..ef998eb4a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -434,42 +434,44 @@ class MaxEigLimiterFunction(torch.autograd.Function): class BasicNormFunction(torch.autograd.Function): # This computes: - # scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale) + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() # return (x - bias) * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). @staticmethod - def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int, + def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, store_output_for_backprop: bool) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim ctx.store_output_for_backprop = store_output_for_backprop ctx.channel_dim = channel_dim - scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() ans = x * scales ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), eps.detach(), scale.detach()) + scales.detach(), bias.detach(), log_scale.detach()) return ans @staticmethod def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, eps, scale = ctx.saved_tensors + ans_or_x, scales, bias, log_scale = ctx.saved_tensors if ctx.store_output_for_backprop: x = ans_or_x / scales else: x = ans_or_x - with torch.cuda.amp.autocast(enabled=False): - assert eps.dtype != torch.float16 and scale.dtype != torch.float16 - x = x.to(torch.float32).detach() - x.requires_grad = True - eps.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, eps and log_scale. - scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale - ans = x * scales - ans.backward(gradient=ans_grad.to(torch.float32)) - return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None - + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None @@ -489,17 +491,14 @@ class BasicNorm(torch.nn.Module): Args: num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. log_scale: the initial log-scale that we multiply the output by; this is learnable. - eps: the initial epsilon value (not in log space) log_scale_min: FloatLike, minimum allowed value of log_scale log_scale_max: FloatLike, maximum allowed value of log_scale - log_eps_min: FloatLike, minimum allowed value of log_eps - log_eps_max: FloatLike, maximum allowed value of log_eps store_output_for_backprop: only possibly affects memory use; recommend to set to True if you think the output of this module is more likely than the input of this module to be required to be stored for the @@ -511,25 +510,19 @@ class BasicNorm(torch.nn.Module): num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. log_scale: float = 1.0, - eps: float = 0.1, - log_scale_min: FloatLike = -1.5, - log_scale_max: FloatLike = 1.5, - log_eps_min: FloatLike = -3.0, - log_eps_max: FloatLike = 3.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, store_output_for_backprop: bool = False ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.log_eps = nn.Parameter(torch.tensor(eps).log().detach()) + self.bias = nn.Parameter(torch.zeros(num_channels)) self.log_scale_min = log_scale_min self.log_scale_max = log_scale_max - self.log_eps_min = log_eps_min - self.log_eps_max = log_eps_max - self.store_output_for_backprop = store_output_for_backprop def forward(self, x: Tensor) -> Tensor: @@ -538,7 +531,12 @@ class BasicNorm(torch.nn.Module): if torch.jit.is_scripting(): channel_dim = self.channel_dim - scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) * + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * self.log_scale.exp()) return x * scales @@ -546,12 +544,8 @@ class BasicNorm(torch.nn.Module): min=float(self.log_scale_min), max=float(self.log_scale_max), training=self.training) - log_eps = limit_param_value(self.log_eps, - min=float(self.log_eps_min), - max=float(self.log_eps_max), - training=self.training) - return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(), + return BasicNormFunction.apply(x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d560a54f0..bc4a49d29 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1885,8 +1885,7 @@ class Conv2dSubsampling(nn.Module): # max_log_eps=0.0 is to prevent both eps and the output of self.out from # getting large, there is an unnecessary degree of freedom. - self.out_norm = BasicNorm(out_channels, eps=1.0, - log_eps_min=-0.1, log_eps_max=0.0) + self.out_norm = BasicNorm(out_channels) self.dropout = Dropout2(dropout)