From d5ad908562f432df54ec4592f89e3218bab8940f Mon Sep 17 00:00:00 2001 From: drawfish Date: Mon, 29 May 2023 16:32:49 +0800 Subject: [PATCH] Add support for export.py --- .../frame_reducer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 671b7565f..b33b712e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -74,7 +74,7 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - if y_lens is not None or self.training is False: + if y_lens is not None or self.training == False: # Limit the maximum number of reduced frames if y_lens is not None: limit_lens = T - y_lens @@ -93,12 +93,12 @@ class FrameReducer(nn.Module): .to(device=x.device) ) T = torch.remainder(T, limit_lens.unsqueeze(1)) - limit_indexes = torch.gather(fake_limit_indexes, 1, T) - limit_mask = torch.full_like( + limit_indexes = torch.gather(fake_limit_indexes, 1, torch.tensor(T)) + limit_mask = (torch.full_like( non_blank_mask, - False, + 0, device=x.device, - ).scatter_(1, limit_indexes, True) + ).scatter_(1, limit_indexes, 1) == 1) non_blank_mask = non_blank_mask | ~limit_mask @@ -112,7 +112,7 @@ class FrameReducer(nn.Module): ) - out_lens ) - max_pad_len = pad_lens_list.max() + max_pad_len = int(pad_lens_list.max().item()) out = F.pad(x, (0, 0, 0, max_pad_len))