From 029c8566e424b44e64da70c6fb532caace9c7d54 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 3 Feb 2023 17:49:54 +0800 Subject: [PATCH] Small fix for frame_reducer.py (#871) --- .../ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 4a19edf66..bc3fc57eb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -44,6 +44,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -54,6 +55,8 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. + blank_id: + The blank id of ctc_output. Returns: out: The frame reduced encoder output with shape [N, T', C]. @@ -65,7 +68,7 @@ class FrameReducer(nn.Module): N, T, C = x.size() padding_mask = make_pad_mask(x_lens) - non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max()