mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
remove if-branch at downsample pad in zipformer for onnx-export compatibility (#965)
This commit is contained in:
parent
d74822d07b
commit
f260a09ed4
@ -781,13 +781,12 @@ class AttentionDownsample(torch.nn.Module):
|
||||
ds = self.downsample
|
||||
d_seq_len = (seq_len + ds - 1) // ds
|
||||
|
||||
# Pad to an exact multiple of self.downsample
|
||||
if seq_len != d_seq_len * ds:
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)
|
||||
# Pad to an exact multiple of self.downsample, could be 0 for onnx-export-compatibility
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)
|
||||
|
||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user