diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index c44cb1eaf..b218c9f94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -76,7 +76,7 @@ class FrameReducer(nn.Module): if y_lens is not None: # Limit the maximum number of reduced frames - limit_lens = T - y_lens + limit_lens = x_lens - y_lens max_limit_len = limit_lens.max().int() fake_limit_indexes = torch.topk( ctc_output[:, :, blank_id], max_limit_len @@ -97,6 +97,7 @@ class FrameReducer(nn.Module): ).scatter_(1, limit_indexes, 1) non_blank_mask = non_blank_mask | ~limit_mask + non_blank_mask = non_blank_mask & ~padding_mask out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max()