diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 4a19edf66..bc3fc57eb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -44,6 +44,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -54,6 +55,8 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. + blank_id: + The blank id of ctc_output. Returns: out: The frame reduced encoder output with shape [N, T', C]. @@ -65,7 +68,7 @@ class FrameReducer(nn.Module): N, T, C = x.size() padding_mask = make_pad_mask(x_lens) - non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max()