mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge 8809c7a99115277a2d4a77c69714ada661ba6653 into abd9437e6d5419a497707748eb935e50976c3b7b
This commit is contained in:
commit
57214969b1
@ -76,7 +76,7 @@ class FrameReducer(nn.Module):
|
|||||||
|
|
||||||
if y_lens is not None:
|
if y_lens is not None:
|
||||||
# Limit the maximum number of reduced frames
|
# Limit the maximum number of reduced frames
|
||||||
limit_lens = T - y_lens
|
limit_lens = x_lens - y_lens
|
||||||
max_limit_len = limit_lens.max().int()
|
max_limit_len = limit_lens.max().int()
|
||||||
fake_limit_indexes = torch.topk(
|
fake_limit_indexes = torch.topk(
|
||||||
ctc_output[:, :, blank_id], max_limit_len
|
ctc_output[:, :, blank_id], max_limit_len
|
||||||
@ -97,6 +97,7 @@ class FrameReducer(nn.Module):
|
|||||||
).scatter_(1, limit_indexes, 1)
|
).scatter_(1, limit_indexes, 1)
|
||||||
|
|
||||||
non_blank_mask = non_blank_mask | ~limit_mask
|
non_blank_mask = non_blank_mask | ~limit_mask
|
||||||
|
non_blank_mask = non_blank_mask & ~padding_mask
|
||||||
|
|
||||||
out_lens = non_blank_mask.sum(dim=1)
|
out_lens = non_blank_mask.sum(dim=1)
|
||||||
max_len = out_lens.max()
|
max_len = out_lens.max()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user