Small fix for frame_reducer.py (#871)

This commit is contained in:
Yifan Yang 2023-02-03 17:49:54 +08:00 committed by GitHub
parent bffce413f0
commit 029c8566e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -44,6 +44,7 @@ class FrameReducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
ctc_output: torch.Tensor,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -54,6 +55,8 @@ class FrameReducer(nn.Module):
`x` before padding.
ctc_output:
The CTC output with shape [N, T, vocab_size].
blank_id:
The blank id of ctc_output.
Returns:
out:
The frame reduced encoder output with shape [N, T', C].
@ -65,7 +68,7 @@ class FrameReducer(nn.Module):
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()