mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Small fix for frame_reducer.py (#871)
This commit is contained in:
parent
bffce413f0
commit
029c8566e4
@@ -44,6 +44,7 @@ class FrameReducer(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
ctc_output: torch.Tensor,
|
||||
blank_id: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@@ -54,6 +55,8 @@ class FrameReducer(nn.Module):
|
||||
`x` before padding.
|
||||
ctc_output:
|
||||
The CTC output with shape [N, T, vocab_size].
|
||||
blank_id:
|
||||
The blank id of ctc_output.
|
||||
Returns:
|
||||
out:
|
||||
The frame reduced encoder output with shape [N, T', C].
|
||||
@@ -65,7 +68,7 @@ class FrameReducer(nn.Module):
|
||||
N, T, C = x.size()
|
||||
|
||||
padding_mask = make_pad_mask(x_lens)
|
||||
- non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
|
||||
+ non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||
|
||||
out_lens = non_blank_mask.sum(dim=1)
|
||||
max_len = out_lens.max()
|
||||
|
Loading…
x
Reference in New Issue
Block a user