mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
remove if-branch at downsampling pad
This commit is contained in:
parent
f59e06c556
commit
8f8b09498f
@ -1229,12 +1229,11 @@ class SimpleDownsample(torch.nn.Module):
|
|||||||
d_seq_len = (seq_len + ds - 1) // ds
|
d_seq_len = (seq_len + ds - 1) // ds
|
||||||
|
|
||||||
# Pad to an exact multiple of self.downsample
|
# Pad to an exact multiple of self.downsample
|
||||||
if seq_len != d_seq_len * ds:
|
# right-pad src, repeating the last element.
|
||||||
# right-pad src, repeating the last element.
|
pad = d_seq_len * ds - seq_len
|
||||||
pad = d_seq_len * ds - seq_len
|
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
src = torch.cat((src, src_extra), dim=0)
|
||||||
src = torch.cat((src, src_extra), dim=0)
|
assert src.shape[0] == d_seq_len * ds
|
||||||
assert src.shape[0] == d_seq_len * ds
|
|
||||||
|
|
||||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user