mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
Fix Blankskip bug.
The output of frame reducer may longer than input, fix it.
This commit is contained in:
parent
b87ed26c09
commit
4f379d6225
@ -76,7 +76,7 @@ class FrameReducer(nn.Module):
|
|||||||
|
|
||||||
if y_lens is not None:
|
if y_lens is not None:
|
||||||
# Limit the maximum number of reduced frames
|
# Limit the maximum number of reduced frames
|
||||||
limit_lens = T - y_lens
|
limit_lens = x_lens - y_lens
|
||||||
max_limit_len = limit_lens.max().int()
|
max_limit_len = limit_lens.max().int()
|
||||||
fake_limit_indexes = torch.topk(
|
fake_limit_indexes = torch.topk(
|
||||||
ctc_output[:, :, blank_id], max_limit_len
|
ctc_output[:, :, blank_id], max_limit_len
|
||||||
@ -97,6 +97,7 @@ class FrameReducer(nn.Module):
|
|||||||
).scatter_(1, limit_indexes, 1)
|
).scatter_(1, limit_indexes, 1)
|
||||||
|
|
||||||
non_blank_mask = non_blank_mask | ~limit_mask
|
non_blank_mask = non_blank_mask | ~limit_mask
|
||||||
|
non_blank_mask = non_blank_mask & ~padding_mask
|
||||||
|
|
||||||
out_lens = non_blank_mask.sum(dim=1)
|
out_lens = non_blank_mask.sum(dim=1)
|
||||||
max_len = out_lens.max()
|
max_len = out_lens.max()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user