mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Small fix for frame_reducer.py (#871)
This commit is contained in:
parent
bffce413f0
commit
029c8566e4
@ -44,6 +44,7 @@ class FrameReducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
ctc_output: torch.Tensor,
|
ctc_output: torch.Tensor,
|
||||||
|
blank_id: int = 0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -54,6 +55,8 @@ class FrameReducer(nn.Module):
|
|||||||
`x` before padding.
|
`x` before padding.
|
||||||
ctc_output:
|
ctc_output:
|
||||||
The CTC output with shape [N, T, vocab_size].
|
The CTC output with shape [N, T, vocab_size].
|
||||||
|
blank_id:
|
||||||
|
The blank id of ctc_output.
|
||||||
Returns:
|
Returns:
|
||||||
out:
|
out:
|
||||||
The frame reduced encoder output with shape [N, T', C].
|
The frame reduced encoder output with shape [N, T', C].
|
||||||
@ -65,7 +68,7 @@ class FrameReducer(nn.Module):
|
|||||||
N, T, C = x.size()
|
N, T, C = x.size()
|
||||||
|
|
||||||
padding_mask = make_pad_mask(x_lens)
|
padding_mask = make_pad_mask(x_lens)
|
||||||
non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
|
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||||
|
|
||||||
out_lens = non_blank_mask.sum(dim=1)
|
out_lens = non_blank_mask.sum(dim=1)
|
||||||
max_len = out_lens.max()
|
max_len = out_lens.max()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user