diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index bd0e625f0..7bda58669 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -509,12 +509,12 @@ class AttentionDownsample(torch.nn.Module): weights = scores.softmax(dim=1) # ans1 is the first `in_channels` channels of the output - ans1 = (src * weights).sum(dim=1) + ans = (src * weights).sum(dim=1) src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) if self.extra_proj is not None: ans2 = self.extra_proj(src) - ans = torch.cat((ans1, ans2), dim=2) + ans = torch.cat((ans, ans2), dim=2) return ans