From 025bcc155ddeeb5fd1a1caa9f6d4c381c4bc3d2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 1 Dec 2022 14:20:15 +0800 Subject: [PATCH] Change scale_min of Conv2dSubsampling from .01 to .1; some cosmetic changes/unimportant bugfixes. --- .../ASR/pruned_transducer_stateless7/zipformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index aa94b9f08..3c23dffd8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1384,10 +1384,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim) # -> (num_heads, batch_size, seq_len, head_dim) bottleneck = torch.matmul(attn_weights, bottleneck) - bottleneck = self.bottleneck_balancer(bottleneck) - bottleneck = self.bottleneck_activation(bottleneck) + bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim) bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim) + + bottleneck = self.bottleneck_balancer(bottleneck) + bottleneck = self.bottleneck_activation(bottleneck) scales = self.from_bottleneck_proj(bottleneck) x = self.in_proj(x) @@ -1505,6 +1507,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels // 2) + x = self.activation(x) # diagnostics only, it's the identity. x = x * s (seq_len, batch_size, embed_dim) = x.shape @@ -1517,7 +1520,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) # now x: (num_heads, batch_size, seq_len, head_dim) x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - x = self.activation(x) # diagnostics only, it's the identity. x = self.whiten(x) x = self.out_proj(x) return x @@ -1720,7 +1722,7 @@ class Conv2dSubsampling(nn.Module): self.scale = nn.Parameter(torch.ones(out_height * layer3_channels)) self.scale_max = 1.0 - self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01)) + self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1)) self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())