From 244633660d8cbeef7d048a1c6c1b13b901007547 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 20 Dec 2022 20:28:03 +0800 Subject: [PATCH] Implement ConvNorm2d and use it in frontend after convnext --- .../pruned_transducer_stateless7/scaling.py | 114 ++++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 9 +- 2 files changed, 119 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 398d0236a..b450999ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -616,6 +616,120 @@ class ConvNorm1d(torch.nn.Module): return x * scales +class PositiveConv2d(nn.Conv2d): + """ + A modified form of nn.Conv2d where the weight parameters are constrained + to be positive and there is no bias. + """ + def __init__( + self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0, + **kwargs): + super().__init__(*args, **kwargs, bias=False) + self.min = min + self.max = max + + # initialize weight to all positive values. + with torch.no_grad(): + self.weight[:] = 1.0 / self.weight[0][0].numel() + + def forward(self, input: Tensor) -> Tensor: + """ + Forward function. Input and returned tensor have shape: + (N, C, H, W) + i.e. (batch_size, num_channels, height, width) + """ + weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max)) + # make absolutely sure there are no negative values. For parameter-averaging-related + # reasons, we prefer to also use limit_param_value to make sure the weights stay + # positive. + weight = weight.abs() + + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.bias, self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +class ConvNorm2d(torch.nn.Module): + """ + This is like BasicNorm except the denominator is summed over time using + convolution with positive weights. + + + Args: + num_channels: the number of channels, e.g. 512. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + conv_min: float = 0.05, + conv_max: float = 1.0, + kernel_size: Tuple[int, int] = (3, 3), + ) -> None: + super().__init__() + self.num_channels = num_channels + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + pad = (kernel_size[0] // 2, kernel_size[1] // 2) + # it has bias=False. + self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad, + min=conv_min, max=conv_max) + + + def forward(self, x: Tensor) -> Tensor: + """ + x shape: (N, C, H, W) + """ + assert x.ndim == 4 and x.shape[1] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed + # region if it happens to exit it. + eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max) + + # sqnorms: (N, 1, H, W) + sqnorms = ( + torch.mean(x ** 2, dim=1, keepdim=True) + ) + # 'counts' is a mechanism to correct for edge effects. + # TODO: key-padding mask + + counts = torch.ones_like(sqnorms) + #if src_key_padding_mask is not None: + # counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0) + #sqnorms = sqnorms * counts + sqnorms = self.conv(sqnorms) + # the clamping is to avoid division by zero for padding frames. + counts = torch.clamp(self.conv(counts), min=0.01) + # scales: (N, 1, H, W) + scales = (sqnorms / counts + eps.exp()) ** -0.5 + return x * scales + + def ScaledLinear(*args, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 30c5b9ef9..876d0a0c3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -28,6 +28,7 @@ from scaling import ( ActivationBalancer, BasicNorm, ConvNorm1d, + ConvNorm2d, Dropout2, MaxEig, DoubleSwish, @@ -1792,8 +1793,8 @@ class Conv2dSubsampling(nn.Module): self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels), ConvNeXt(layer2_channels), - BasicNorm(layer2_channels, - channel_dim=1)) + ConvNorm2d(layer2_channels, + kernel_size=(15, 7))) # (time, freq) self.conv2 = nn.Sequential( @@ -1812,8 +1813,8 @@ class Conv2dSubsampling(nn.Module): self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels), ConvNeXt(layer3_channels), ConvNeXt(layer3_channels), - BasicNorm(layer3_channels, - channel_dim=1)) + ConvNorm2d(layer3_channels, + kernel_size=(15, 5))) # (time, freq) out_height = (((in_channels - 1) // 2) - 1) // 2