From caf23546edea120f402b03916d3a5647f54a28d8 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Mon, 6 Feb 2023 12:17:45 +0800
Subject: [PATCH] No more T < S after frame_reducer (#875)

* No more T < S after frame_reducer

* Fix for style check

* Adjust the permissions

* Add support for inference to frame_reducer

* Fix for flake8 check

---------

Co-authored-by: yifanyang
---
 .../__init__.py                               |  0
 .../export_onnx.py                            |  0
 .../frame_reducer.py                          | 74 +++++++++++++++----
 .../lconv.py                                  |  0
 .../model.py                                  | 10 ++-
 .../onnx_pretrained.py                        |  0
 .../train.py                                  |  3 +-
 7 files changed, 65 insertions(+), 22 deletions(-)
 mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py
 mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
 mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
 mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
 mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
old mode 100644
new mode 100755
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
index bc3fc57eb..0841f7cf1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
-# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
-#                                        Zengwei Yao)
+# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
+#                                        Zengwei Yao,
+#                                        Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -18,7 +19,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         ctc_output: torch.Tensor,
+        y_lens: Optional[torch.Tensor] = None,
         blank_id: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
             `x` before padding.
           ctc_output:
             The CTC output with shape [N, T, vocab_size].
+          y_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `y` before padding.
           blank_id:
             The blank id of ctc_output.
         Returns:
@@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
             A tensor of shape (batch_size,) containing the number of frames in
             `out` before padding.
         """
-
         N, T, C = x.size()
 
         padding_mask = make_pad_mask(x_lens)
         non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
 
+        if y_lens is not None:
+            # Limit the maximum number of reduced frames
+            limit_lens = T - y_lens
+            max_limit_len = limit_lens.max().int()
+            fake_limit_indexes = torch.topk(
+                ctc_output[:, :, blank_id], max_limit_len
+            ).indices
+            T = (
+                torch.arange(max_limit_len)
+                .expand_as(
+                    fake_limit_indexes,
+                )
+                .to(device=x.device)
+            )
+            T = torch.remainder(T, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            limit_mask = torch.full_like(
+                non_blank_mask,
+                False,
+                device=x.device,
+            ).scatter_(1, limit_indexes, True)
+
+            non_blank_mask = non_blank_mask | ~limit_mask
+
         out_lens = non_blank_mask.sum(dim=1)
         max_len = out_lens.max()
-        pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
+        pad_lens_list = (
+            torch.full_like(
+                out_lens,
+                max_len.item(),
+                device=x.device,
+            )
+            - out_lens
+        )
         max_pad_len = pad_lens_list.max()
 
         out = F.pad(x, (0, 0, 0, max_pad_len))
@@ -82,26 +117,30 @@ class FrameReducer(nn.Module):
 
         out = out[total_valid_mask].reshape(N, -1, C)
 
-        return out.to(device=x.device), out_lens.to(device=x.device)
+        return out, out_lens
 
 
 if __name__ == "__main__":
     import time
-    from torch.nn.utils.rnn import pad_sequence
 
     test_times = 10000
+    device = "cuda:0"
     frame_reducer = FrameReducer()
 
     # non zero case
-    x = torch.ones(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
-    x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.log(
+        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
+    )
 
     avg_time = 0
     for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
@@ -109,14 +148,17 @@ if __name__ == "__main__":
     print(avg_time / test_times)
 
     # all zero case
-    x = torch.zeros(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
+    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
 
     avg_time = 0
     for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
old mode 100755
new mode 100644
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
old mode 100755
new mode 100644
index 86acc5a10..0582b289f
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
@@ -131,6 +131,10 @@ class Transducer(nn.Module):
         # compute ctc log-probs
         ctc_output = self.ctc_output(encoder_out)
 
+        # y_lens
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         # blank skip
         blank_id = self.decoder.blank_id
 
@@ -146,16 +150,14 @@ class Transducer(nn.Module):
                 encoder_out,
                 x_lens,
                 ctc_output,
+                y_lens,
                 blank_id,
             )
         else:
             encoder_out_fr = encoder_out
             x_lens_fr = x_lens
 
-        # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
-
+        # sos_y
         sos_y = add_sos(y, sos_id=blank_id)
 
         # sos_y_padded: [B, S + 1], start with SOS.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
old mode 100644
new mode 100755
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
index b282ab9db..ea280e642 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
 #                                            Mingshuang Luo,
@@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 """
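
Note: the vectorized limit-mask logic added to FrameReducer.forward above (topk over the blank posterior, the remainder/gather trick, and scatter_) can be hard to follow on first read. The sketch below restates the idea in a simplified, per-utterance form: drop frames whose CTC blank posterior exceeds 0.9, but re-admit the least-blank dropped frames whenever an utterance would otherwise keep fewer than y_lens frames, so that T' >= S always holds for the pruned RNN-T loss. This is an illustrative reading of the patch, not part of it; the function name reduce_frames_with_floor, the Python loop, and the unpadded list output are choices made here for clarity, not anything that exists in icefall.

#!/usr/bin/env python3
# Simplified, loop-based sketch of blank-skipping frame reduction with a
# lower bound of y_lens kept frames per utterance (illustrative only).

import math
from typing import List, Tuple

import torch


def reduce_frames_with_floor(
    x: torch.Tensor,           # (N, T, C) encoder output
    x_lens: torch.Tensor,      # (N,) number of valid frames per utterance
    ctc_output: torch.Tensor,  # (N, T, vocab_size) CTC log-probs
    y_lens: torch.Tensor,      # (N,) label lengths; lower bound on kept frames
    blank_id: int = 0,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    N, T, C = x.shape
    blank_logp = ctc_output[:, :, blank_id]
    outs, out_lens = [], []

    for i in range(N):
        valid = int(x_lens[i])
        # Keep a frame if its blank log-prob is below log(0.9).
        keep = blank_logp[i, :valid] < math.log(0.9)

        deficit = int(y_lens[i]) - int(keep.sum())
        if deficit > 0:
            # Too few frames survived: re-admit the `deficit` dropped frames
            # with the lowest blank probability (least confidently blank).
            dropped = (~keep).nonzero(as_tuple=True)[0]
            order = blank_logp[i, dropped].argsort()  # least blank first
            keep[dropped[order[:deficit]]] = True

        outs.append(x[i, :valid][keep])
        out_lens.append(int(keep.sum()))

    return outs, torch.tensor(out_lens)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 20, 4)
    x_lens = torch.tensor([20, 16])
    y_lens = torch.tensor([8, 12])
    logits = torch.randn(2, 20, 10)
    logits[:, :, 0] += 6.0  # make most frames confidently blank
    ctc_output = torch.log_softmax(logits, dim=-1)

    outs, out_lens = reduce_frames_with_floor(x, x_lens, ctc_output, y_lens)
    assert (out_lens >= y_lens).all()  # T' >= S is guaranteed
    print(out_lens)

The patch achieves the same guarantee without a Python loop: it marks, per utterance, only the (T - y_lens) most-blank positions as droppable (topk plus scatter_), so every other position is forced into non_blank_mask and at least y_lens frames always survive.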