diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index b21b531d0..398d0236a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -511,7 +511,8 @@ class PositiveConv1d(nn.Conv1d):
         self.max = max
 
         # initialize weight to all positive values.
-        self.weight[:] = 1.0 / self.weight[0][0].numel()
+        with torch.no_grad():
+            self.weight[:] = 1.0 / self.weight[0][0].numel()
 
     def forward(self, input: Tensor) -> Tensor:
         """
@@ -519,7 +520,6 @@
 
           (N, C, H) i.e. (batch_size, num_channels, height)
         """
-        weight = self.weight
         weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
         # make absolutely sure there are no negative values. For parameter-averaging-related
         # reasons, we prefer to also use limit_param_value to make sure the weights stay
@@ -556,7 +556,6 @@ class ConvNorm1d(torch.nn.Module):
     def __init__(
         self,
         num_channels: int,
-        channel_dim: int = -1,  # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
         eps_min: float = -3.0,
@@ -567,7 +566,6 @@
     ) -> None:
         super().__init__()
         self.num_channels = num_channels
-        self.channel_dim = channel_dim
         if learn_eps:
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
@@ -576,7 +574,8 @@
         self.eps_max = eps_max
         pad = kernel_size // 2
         # it has bias=False.
-        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad)
+        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
+                                   min=conv_min, max=conv_max)
 
     def forward(self,
                 x: Tensor,
@@ -598,7 +597,7 @@
 
         # gradients to allow the parameter to get back into the allowed
         # region if it happens to exit it.
-        eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
 
         # sqnorms: (N, 1, T)
         sqnorms = (
@@ -611,9 +610,9 @@
         sqnorms = sqnorms * counts
         sqnorms = self.conv(sqnorms)
         # the clamping is to avoid division by zero for padding frames.
-        counts = self.conv(counts).clamp(min=0.01)
+        counts = torch.clamp(self.conv(counts), min=0.01)
         # scales: (N, 1, T)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5 
+        scales = (sqnorms / counts + eps.exp()) ** -0.5
         #
 
         return x * scales
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 707314b93..30c5b9ef9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1824,7 +1824,7 @@ class Conv2dSubsampling(nn.Module):
 
         self.out = nn.Linear(out_height * layer3_channels, out_channels)
 
-        self.out_norm = BasicNorm(out_channels, channel_dim=-1)
+        self.out_norm = ConvNorm1d(out_channels)
 
         self.dropout = Dropout2(dropout)
 
@@ -1859,9 +1859,11 @@
                              min=float(self.scale_min),
                              max=float(self.scale_max))
 
-        x = self.out(x)
-        x = self.out_norm(x)
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out(x)
+        x = x.transpose(1, 2)  # (batch, channels, time)
+        x = self.out_norm(x)
+        x = x.transpose(1, 2)  # (batch, time=((T-1)//2 - 1))//2, channels)
         x = self.dropout(x)
         return x
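
Note on the calling convention this patch introduces: BasicNorm normalized over channel_dim=-1, whereas ConvNorm1d is built on PositiveConv1d (an nn.Conv1d subclass) and so fixes the channel dimension at dim 1. That is why channel_dim is dropped from the constructor and Conv2dSubsampling.forward now transposes around the norm. A minimal usage sketch of the new convention follows; the batch/time/channel sizes are illustrative only, and it assumes the remaining arguments of ConvNorm1d.forward (cut off at the hunk boundary above) are optional, which is consistent with the single-argument call self.out_norm(x) in zipformer.py:

import torch

from scaling import ConvNorm1d  # path as checked out in this PR

norm = ConvNorm1d(num_channels=384)  # no channel_dim argument after this patch

x = torch.randn(8, 100, 384)  # (batch, time, channels), as produced by self.out
x = x.transpose(1, 2)         # (batch, channels, time): ConvNorm1d wants channels at dim 1
x = norm(x)                   # scales each frame by (smoothed mean-square + eps) ** -0.5
x = x.transpose(1, 2)         # back to (batch, time, channels)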