From 4f379d62253ee14639aa6977dc14ead4c7e7bdc1 Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Thu, 7 Dec 2023 20:59:24 +0800 Subject: [PATCH] Fix Blankskip bug. The output of frame reducer may longer than input, fix it. --- .../ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index c44cb1eaf..b218c9f94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -76,7 +76,7 @@ class FrameReducer(nn.Module): if y_lens is not None: # Limit the maximum number of reduced frames - limit_lens = T - y_lens + limit_lens = x_lens - y_lens max_limit_len = limit_lens.max().int() fake_limit_indexes = torch.topk( ctc_output[:, :, blank_id], max_limit_len @@ -97,6 +97,7 @@ class FrameReducer(nn.Module): ).scatter_(1, limit_indexes, 1) non_blank_mask = non_blank_mask | ~limit_mask + non_blank_mask = non_blank_mask & ~padding_mask out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max()