From 8f8b09498f91b5d73508beca130db4b088cb5932 Mon Sep 17 00:00:00 2001 From: danqing fu Date: Tue, 6 Jun 2023 10:05:59 +0800 Subject: [PATCH] remove if-branch at downsampling pad --- egs/librispeech/ASR/zipformer/zipformer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index ea4e6711f..15022947f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1229,12 +1229,11 @@ class SimpleDownsample(torch.nn.Module): d_seq_len = (seq_len + ds - 1) // ds # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds src = src.reshape(d_seq_len, ds, batch_size, in_channels)