From 494139d27a8b34a4c85992869667cdb4b8dd7f20 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 20 Dec 2022 19:15:14 +0800
Subject: [PATCH] Replace BasicNorm of encoder layers with ConvNorm1d

---
Review notes (ignored by git-am; not part of the commit message):
 * Line breaks, indentation and blank context lines were restored from a
   whitespace-collapsed copy of this patch.  Context whitespace is inferred
   from the hunk line counts -- TODO confirm against the target tree.
 * PositiveConv1d.__init__: weight initialization now runs under
   torch.no_grad().
 * PositiveConv1d.forward: removed a dead assignment; `_single(0)` replaced
   by the equivalent literal 0.
 * ConvNorm1d.forward: `counts` is clamped away from zero so that fully
   masked windows give 0 instead of 0 / 0 = NaN.
 * Docstrings/comments extended; hunk header and diffstat updated to match.

 .../pruned_transducer_stateless7/scaling.py   | 148 ++++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |  11 +-
 2 files changed, 155 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 8a9867103..4c3271ac6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -498,6 +498,154 @@ class BasicNorm(torch.nn.Module):
 
 
 
+class PositiveConv1d(nn.Conv1d):
+    """
+    A modified form of nn.Conv1d where the weight parameters are constrained
+    to be positive and there is no bias.
+
+    Args:
+       *args, **kwargs: forwarded to nn.Conv1d.  Do not pass `bias`: this
+          class always passes bias=False itself.
+       min: lower limit for the weights, given to limit_param_value() in
+          forward().
+       max: upper limit for the weights, given to limit_param_value() in
+          forward().
+    """
+    def __init__(
+            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
+            **kwargs):
+        super().__init__(*args, **kwargs, bias=False)
+        self.min = min
+        self.max = max
+
+        # initialize weight to all positive values: every element is
+        # 1 / kernel_size.  This has to be done under no_grad(), because an
+        # in-place write to a leaf parameter that requires grad raises
+        # RuntimeError.
+        with torch.no_grad():
+            self.weight[:] = 1.0 / self.weight[0][0].numel()
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Forward function.  Input and returned tensor have shape:
+           (N, C, H)
+        i.e. (batch_size, num_channels, height)
+        """
+        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
+        # make absolutely sure there are no negative values.  For parameter-averaging-related
+        # reasons, we prefer to also use limit_param_value to make sure the weights stay
+        # positive.
+        weight = weight.abs()
+
+        if self.padding_mode != 'zeros':
+            # the padding has already been applied by F.pad(), so the
+            # convolution itself uses padding 0.
+            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            0, self.dilation, self.groups)
+        return F.conv1d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+
+class ConvNorm1d(torch.nn.Module):
+    """
+    This is like BasicNorm except the denominator is summed over time using
+    convolution with positive weights.
+
+
+    Args:
+       num_channels: the number of channels, e.g. 512.
+       channel_dim: stored, but not used by forward(), which always
+         treats dim 1 of its (N, C, T) input as the channel dimension.
+       eps: the initial "epsilon" that we add as ballast in:
+             scale = ((input_vec**2).mean() + epsilon)**-0.5
+         Note: our epsilon is actually large, but we keep the name
+         to indicate the connection with conventional LayerNorm.
+       learn_eps: if true, we learn epsilon; if false, we keep it
+         at the initial value.
+       eps_min: lower limit on log(epsilon); applied (by clamping) on a
+         random 25% of training-mode forward passes.
+       eps_max: upper limit on log(epsilon); applied in the same way.
+       conv_min: currently unused (the convolution keeps the default
+         limits of PositiveConv1d).
+       conv_max: currently unused, like conv_min.
+       kernel_size: width of the convolution over time.  It is padded by
+         kernel_size // 2 on each side, so use an odd value to keep the
+         output length equal to T.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
+        conv_min: float = 0.1,
+        conv_max: float = 1.0,
+        kernel_size: int = 15,
+    ) -> None:
+        super().__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        # self.eps holds log(epsilon); forward() applies exp().
+        if learn_eps:
+            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+        else:
+            self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max
+        pad = kernel_size // 2
+        # it has bias=False.
+        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
+
+
+    def forward(self, x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        """
+        x shape: (N, C, T)
+
+        src_key_padding_mask: the mask for the src keys per batch (optional):
+            (N, T), contains True in masked positions.
+
+        Returns a tensor with the same shape as x.
+        """
+        assert x.ndim == 3 and x.shape[1] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp eps between the min
+            # and max; this will encourage it to learn parameters within the
+            # allowed range by making parameters that are outside the allowed
+            # range noisy.
+
+            # The rest of the time eps is used unclamped, so it still gets
+            # gradients to allow the parameter to get back into the allowed
+            # region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+
+        # sqnorms: (N, 1, T)
+        sqnorms = (
+            torch.mean(x ** 2, dim=1, keepdim=True)
+        )
+        # 'counts' is a mechanism to correct for edge effects.
+        counts = torch.ones_like(sqnorms)
+        if src_key_padding_mask is not None:
+            counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
+            sqnorms = sqnorms * counts
+        sqnorms = self.conv(sqnorms)
+        counts = self.conv(counts)
+        # Where every frame under the kernel is masked, sqnorms and counts are
+        # both zero; keep counts away from zero so that we get 0 there
+        # instead of 0 / 0 = NaN.
+        counts = counts.clamp(min=torch.finfo(counts.dtype).tiny)
+        # scales: (N, 1, T)
+        scales = (sqnorms / counts + eps.exp()) ** -0.5
+        return x * scales
+
+
+
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6d078345c..707314b93 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    ConvNorm1d,
     Dropout2,
     MaxEig,
     DoubleSwish,
@@ -443,7 +444,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim)
+        self.norm_final = ConvNorm1d(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
@@ -555,8 +556,10 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-
-        src = self.norm_final(self.balancer(src))
+        src = self.balancer(src)
+        src = src.permute(1, 2, 0)  # (batch, channels, time)
+        src = self.norm_final(src, src_key_padding_mask)
+        src = src.permute(2, 0, 1)  # (time, batch, channels)
 
         delta = src - src_orig
 
@@ -1606,7 +1609,7 @@ class ConvolutionModule(nn.Module):
         Args:
             x: Input tensor (#time, batch, channels).
             src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains bool in masked positions.
+                (batch, #time), contains True in masked positions.
 
         Returns:
             Tensor: Output tensor (#time, batch, channels).