diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ff86165b7..0b9c7d44a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -430,6 +430,54 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BasicNormFunction(torch.autograd.Function):
+    # This computes:
+    #    scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+    #    return x * scales
+    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+    # it only needs to store the returned value (chances are, this will also be needed
+    # for some other reason, related to the next operation, so we can save memory).
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
+                store_output_for_backprop: bool) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
+        ctx.store_output_for_backprop = store_output_for_backprop
+        ctx.channel_dim = channel_dim
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
+        ans = x * scales
+        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
+                              scales.detach(), bias.detach(), log_scale.detach())
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor) -> Tensor:
+        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+        if ctx.store_output_for_backprop:
+            x = ans_or_x / scales
+        else:
+            x = ans_or_x
+        x = x.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        log_scale.requires_grad = True
+        with torch.enable_grad():
+            # recompute scales from x, bias and log_scale.
+            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
+            ans = x * scales
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -450,47 +498,57 @@ class BasicNorm(torch.nn.Module):
           interpreted as an offset from the input's ndim if negative.
           This is NOT the num_channels; it should typically be one of
           {-2, -1, 0, 1, 2, 3}.
-       eps: the initial "epsilon" that we add as ballast in:
-            scale = ((input_vec**2).mean() + epsilon)**-0.5
-          Note: our epsilon is actually large, but we keep the name
-          to indicate the connection with conventional LayerNorm.
-       learn_eps: if true, we learn epsilon; if false, we keep it
-          at the initial value.
-       eps_min: float
-       eps_max: float
+       log_scale: the initial log-scale that we multiply the output by; this
+          is learnable.
+       log_scale_min: FloatLike, minimum allowed value of log_scale
+       log_scale_max: FloatLike, maximum allowed value of log_scale
+       store_output_for_backprop: only possibly affects memory use; recommend
+          to set to True if you think the output of this module is more likely
+          than the input of this module to be required to be stored for the
+          backprop.
     """
 
     def __init__(
-        self,
-        num_channels: int,
-        channel_dim: int = -1,  # CAUTION: see documentation.
-        eps: float = 0.25,
-        learn_eps: bool = True,
-        eps_min: float = -3.0,
-        eps_max: float = 3.0,
+            self,
+            num_channels: int,
+            channel_dim: int = -1,  # CAUTION: see documentation.
+            log_scale: float = 1.0,
+            log_scale_min: float = -1.5,
+            log_scale_max: float = 1.5,
+            store_output_for_backprop: bool = False
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        if learn_eps:
-            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
-        else:
-            self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.eps_min = eps_min
-        self.eps_max = eps_max
+        self.log_scale = nn.Parameter(torch.tensor(log_scale))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.log_scale_min = log_scale_min
+        self.log_scale_max = log_scale_max
+        self.store_output_for_backprop = store_output_for_backprop
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        eps = self.eps
-        if self.training:
-            eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max)
-        eps = eps.exp()
-        scales = (
-            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps)
-            # / (1.0 + eps)
-        ) ** -0.5
-        return x * scales
+
+        if torch.jit.is_scripting():
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
+                      self.log_scale.exp())
+            return x * scales
+
+        log_scale = limit_param_value(self.log_scale,
+                                      min=float(self.log_scale_min),
+                                      max=float(self.log_scale_max),
+                                      training=self.training)
+
+        return BasicNormFunction.apply(x, self.bias, log_scale,
+                                       self.channel_dim,
+                                       self.store_output_for_backprop)
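
Reviewer note, not part of the patch: a quick way to convince yourself that the memory-efficient BasicNormFunction matches the straightforward computation is the sketch below. It assumes BasicNormFunction is in scope (e.g. run from inside scaling.py); the shapes, the double precision, and channel_dim=-1 are illustrative choices so that torch.autograd.gradcheck is meaningful.

    import torch

    def basic_norm_ref(x, bias, log_scale):
        # The plain version of the same computation, with channel_dim == -1.
        scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()
        return x * scales

    x = torch.randn(2, 5, 8, dtype=torch.double, requires_grad=True)
    bias = torch.zeros(8, dtype=torch.double, requires_grad=True)
    log_scale = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

    # Forward values agree with the plain implementation.
    y = BasicNormFunction.apply(x, bias, log_scale, -1, True)
    assert torch.allclose(y, basic_norm_ref(x, bias, log_scale))

    # gradcheck exercises the recompute-in-backward path; with
    # store_output_for_backprop=True, backward reconstructs x as ans / scales.
    torch.autograd.gradcheck(
        lambda x, b, s: BasicNormFunction.apply(x, b, s, -1, True),
        (x, bias, log_scale))
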
@@ -516,7 +574,8 @@ class PositiveConv1d(nn.Conv1d):
           (N, C, H) i.e. (batch_size, num_channels, height)
         """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
+        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
+                                   training=self.training)
         # make absolutely sure there are no negative values.  For parameter-averaging-related
         # reasons, we prefer to also use limit_param_value to make sure the weights stay
         # positive.
@@ -634,7 +693,8 @@ class PositiveConv2d(nn.Conv2d):
           (N, C, H, W) i.e. (batch_size, num_channels, height, width)
         """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
+        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
+                                   training=self.training)
         # make absolutely sure there are no negative values.  For parameter-averaging-related
         # reasons, we prefer to also use limit_param_value to make sure the weights stay
         # positive.
@@ -1156,13 +1216,14 @@ class LimitParamValue(torch.autograd.Function):
 
 def limit_param_value(x: Tensor,
                       min: float, max: float,
-                      prob: float = 0.6):
+                      prob: float = 0.6,
+                      training: bool = True):
     # You apply this to (typically) an nn.Parameter during training to ensure
     # that its elements (mostly) stay within a supplied range.  This is done by
     # modifying the gradients in backprop.
     # It's not necessary to do this on every batch: do it only some of the time,
     # to save a little time.
-    if random.random() < prob:
+    if training and random.random() < prob:
         return LimitParamValue.apply(x, min, max)
     else:
         return x
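
The new training argument folds the "only do this in training" branch, which call sites previously wrote by hand, into limit_param_value itself: with training=False the tensor is returned untouched, so eval/inference behavior is deterministic. A minimal sketch of the calling convention (values illustrative; limit_param_value assumed to be in scope):

    import torch

    p = torch.nn.Parameter(torch.tensor([0.2, 5.0]))

    # Training: with probability prob (default 0.6), gradients are modified in
    # backprop so that elements outside [min, max] get pulled back in range.
    y_train = limit_param_value(p, min=-3.0, max=3.0, training=True)

    # Eval: always a no-op that returns the parameter itself.
    y_eval = limit_param_value(p, min=-3.0, max=3.0, training=False)
    assert y_eval is p
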
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c08d66b0b..2222de303 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -453,7 +453,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim, eps_max=4.0)
+        self.norm_final = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
@@ -868,11 +868,10 @@ class SimpleCombiner(torch.nn.Module):
 
         dim2 = src2.shape[-1]
 
-        weight1 = self.weight1
-        if self.training:
-            weight1 = limit_param_value(weight1,
-                                        min=self.min_weight[0],
-                                        max=1.0-self.min_weight[1])
+        weight1 = limit_param_value(self.weight1,
+                                    min=self.min_weight[0],
+                                    max=1.0-self.min_weight[1],
+                                    training=self.training)
 
         src1_dim = src1.shape[-1]
         src2_dim = src2.shape[-1]
@@ -1896,7 +1895,8 @@ class Conv2dSubsampling(nn.Module):
 
         x = x * limit_param_value(self.scale,
                                   min=float(self.scale_min),
-                                  max=float(self.scale_max))
+                                  max=float(self.scale_max),
+                                  training=self.training)
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         x = self.out(x)
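
The zipformer.py hunks are pure call-site cleanups for the two APIs changed above: BasicNorm no longer takes eps_max, and the hand-written "if self.training:" branches collapse into the training= argument. A hypothetical smoke test of the updated BasicNorm follows; the 384-channel width and the (seq_len, batch, channels) layout are illustrative, not taken from the patch:

    import torch

    norm = BasicNorm(num_channels=384, store_output_for_backprop=True)
    x = torch.randn(10, 4, 384, requires_grad=True)
    y = norm(x)
    assert y.shape == x.shape
    y.sum().backward()  # runs BasicNormFunction's recompute-in-backward path
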