From 59eb465b3cd47a212117b535644f24ed190093e1 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 23 Dec 2022 17:55:36 +0800 Subject: [PATCH] optimize frame_reducer.py (#783) Co-authored-by: yifanyang --- .../pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 3de21a293..9fe88929d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -66,19 +66,14 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - T_range = torch.arange(x.shape[1], device=x.device) frames_list: List[torch.Tensor] = [] lens_list: List[int] = [] for i in range(x.shape[0]): - indexes = torch.masked_select( - T_range, - non_blank_mask[i], - ) - frames = x[i][indexes] + frames = x[i][non_blank_mask[i]] frames_list.append(frames) lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list).transpose(0, 1) + x_fr = pad_sequence(frames_list, batch_first=True) x_lens_fr = torch.tensor(lens_list).to(device=x.device) return x_fr, x_lens_fr