diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index b33b712e4..1517a494f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -74,7 +74,7 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - if y_lens is not None or self.training == False: + if y_lens is not None or self.training is False: # Limit the maximum number of reduced frames if y_lens is not None: limit_lens = T - y_lens @@ -85,15 +85,15 @@ class FrameReducer(nn.Module): fake_limit_indexes = torch.topk( ctc_output[:, :, blank_id], max_limit_len ).indices - T = ( + _T = ( torch.arange(max_limit_len) .expand_as( fake_limit_indexes, ) .to(device=x.device) ) - T = torch.remainder(T, limit_lens.unsqueeze(1)) - limit_indexes = torch.gather(fake_limit_indexes, 1, torch.tensor(T)) + _T = torch.remainder(_T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, _T) limit_mask = (torch.full_like( non_blank_mask, 0,