diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9f7dfe961..9ccdaab99 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -341,11 +341,6 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim,
                                                dropout)
 
-        self.feed_forward3 = FeedforwardModule(d_model,
-                                               feedforward_dim,
-                                               dropout)
-
-
         self.conv_module1 = ConvolutionModule(d_model,
                                               cnn_module_kernel)
 
@@ -451,16 +446,13 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
-
-        src = src + self.feed_forward2(src)
-
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn.forward2(src, attn_weights)
 
         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
-        src = src + self.feed_forward3(src)
+        src = src + self.feed_forward2(src)
 
         src = self.norm_final(self.balancer(src))
 
@@ -855,7 +847,7 @@ class SimpleCombiner(torch.nn.Module):
                  min_weight: Tuple[float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         assert dim2 >= dim1
-        self.weight1 = nn.Parameter(torch.zeros(()))
+        self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0])
         self.min_weight = min_weight
 
     def forward(self,
@@ -878,9 +870,6 @@ class SimpleCombiner(torch.nn.Module):
                                   max=1.0-self.min_weight[1])
 
 
-        src1 = src1 * weight1
-        src2 = src2 * (1.0 - weight1)
-
         src1_dim = src1.shape[-1]
         src2_dim = src2.shape[-1]
         if src1_dim != src2_dim:
@@ -893,6 +882,8 @@ class SimpleCombiner(torch.nn.Module):
             else:
                 src1 = src1[:src2_dim]
 
+        src1 = src1 * weight1
+        src2 = src2 * (1.0 - weight1)
 
         return src1 + src2
 
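For readers skimming the patch: the SimpleCombiner hunks are the subtle part, so here is a minimal runnable sketch (PyTorch) of what the combiner computes after this diff is applied. The class name SimpleCombinerSketch, the zero-padding branch for src1_dim < src2_dim (that branch is elided between the last two hunks), and the ungated clamp are illustration-only assumptions; the rest mirrors the lines shown above. The point the sketch makes explicit: weight1 is now a learnable vector of shape (dim2,) initialized to min_weight[0] rather than a zero-initialized scalar, so the multiply must happen after src1 is resized to src2's width, which is exactly why the patch moves the two weighting lines below the dimension-matching block.

from typing import Tuple

import torch
import torch.nn as nn


class SimpleCombinerSketch(nn.Module):
    """Weighted sum of two tensors whose last dims may differ (dim2 >= dim1)."""

    def __init__(self, dim1: int, dim2: int,
                 min_weight: Tuple[float, float] = (0.0, 0.0)):
        super().__init__()
        assert dim2 >= dim1
        # After the patch: a per-channel weight of shape (dim2,), initialized
        # to min_weight[0], replacing the old scalar initialized to zero.
        self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0])
        self.min_weight = min_weight

    def forward(self, src1: torch.Tensor, src2: torch.Tensor) -> torch.Tensor:
        # Clamp bounds come from the context lines of the diff; whether the
        # clamp is gated (e.g. on training or scripting) is not visible here.
        weight1 = self.weight1.clamp(min=self.min_weight[0],
                                     max=1.0 - self.min_weight[1])

        # The patch applies the weighting *after* dimension matching: weight1
        # now has shape (dim2,), so src1 must first be resized to src2's last
        # dim for the elementwise multiply to broadcast correctly.
        src1_dim = src1.shape[-1]
        src2_dim = src2.shape[-1]
        if src1_dim != src2_dim:
            if src1_dim < src2_dim:
                # Zero-pad up to src2's width; this branch is elided in the
                # diff, so padding here is an assumption for illustration.
                pad = torch.zeros(*src1.shape[:-1], src2_dim - src1_dim,
                                  dtype=src1.dtype, device=src1.device)
                src1 = torch.cat((src1, pad), dim=-1)
            else:
                src1 = src1[..., :src2_dim]  # the diff writes src1[:src2_dim]

        src1 = src1 * weight1
        src2 = src2 * (1.0 - weight1)
        return src1 + src2


# Example: combining a 192-dim stream with a 256-dim stream.
combiner = SimpleCombinerSketch(dim1=192, dim2=256, min_weight=(0.1, 0.1))
out = combiner(torch.randn(10, 4, 192), torch.randn(10, 4, 256))
assert out.shape == (10, 4, 256)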