diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 3959c0bb2..5b75b8d35 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -781,13 +781,12 @@ class AttentionDownsample(torch.nn.Module):
         ds = self.downsample
         d_seq_len = (seq_len + ds - 1) // ds

-        # Pad to an exact multiple of self.downsample
-        if seq_len != d_seq_len * ds:
-            # right-pad src, repeating the last element.
-            pad = d_seq_len * ds - seq_len
-            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-            src = torch.cat((src, src_extra), dim=0)
-            assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)
+        # Pad to an exact multiple of self.downsample; pad may be 0 (kept unconditional for ONNX-export compatibility).
+        # Right-pad src, repeating the last element.
+        pad = d_seq_len * ds - seq_len
+        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+        src = torch.cat((src, src_extra), dim=0)
+        assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)

         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
         scores = (src * self.query).sum(dim=-1, keepdim=True)
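
Note on the change (not part of the patch): `torch.onnx.export` traces the model, so a Python-level `if` that depends on the input's sequence length is evaluated once at trace time and only the branch taken for the example input ends up in the exported graph. Making the padding unconditional keeps the exported graph valid for every sequence length: when `seq_len` is already a multiple of `ds`, `pad` is 0, `expand` yields an empty tensor, and `torch.cat` is a no-op. The standalone sketch below (variable values are illustrative, not taken from the patch) exercises both the zero-pad and nonzero-pad cases through the same branch-free code path:

import torch

batch_size, in_channels, ds = 2, 4, 4

for seq_len in (8, 7):  # a multiple of ds, and not a multiple of ds
    src = torch.randn(seq_len, batch_size, in_channels)
    d_seq_len = (seq_len + ds - 1) // ds   # ceil(seq_len / ds)
    pad = d_seq_len * ds - seq_len         # 0 for seq_len=8, 1 for seq_len=7

    # Expanding a size-1 dim to size 0 yields an empty tensor, so the cat
    # below is a no-op when pad == 0; when pad > 0 the same two lines
    # repeat the last frame `pad` times.
    src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
    src = torch.cat((src, src_extra), dim=0)

    assert src.shape[0] == d_seq_len * ds  # holds in both cases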