diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 3de21a293..9fe88929d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -66,19 +66,14 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - T_range = torch.arange(x.shape[1], device=x.device) frames_list: List[torch.Tensor] = [] lens_list: List[int] = [] for i in range(x.shape[0]): - indexes = torch.masked_select( - T_range, - non_blank_mask[i], - ) - frames = x[i][indexes] + frames = x[i][non_blank_mask[i]] frames_list.append(frames) lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list).transpose(0, 1) + x_fr = pad_sequence(frames_list, batch_first=True) x_lens_fr = torch.tensor(lens_list).to(device=x.device) return x_fr, x_lens_fr