Fix Blankskip bug.

The output of frame reducer may longer than input, fix it.
This commit is contained in:
kobenaxie 2023-12-07 20:59:24 +08:00 committed by GitHub
parent b87ed26c09
commit 4f379d6225
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -76,7 +76,7 @@ class FrameReducer(nn.Module):
if y_lens is not None:
# Limit the maximum number of reduced frames
limit_lens = T - y_lens
limit_lens = x_lens - y_lens
max_limit_len = limit_lens.max().int()
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
@ -97,6 +97,7 @@ class FrameReducer(nn.Module):
).scatter_(1, limit_indexes, 1)
non_blank_mask = non_blank_mask | ~limit_mask
non_blank_mask = non_blank_mask & ~padding_mask
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()