From bf5f0342a24b2dd92a908980ba5e8619ca2a08f4 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Mon, 6 Feb 2023 10:37:07 +0800 Subject: [PATCH 1/2] Add streaming onnx export for zipformer (#831) * add streaming onnx export for zipformer * update triton support * add comments * add ci test * add onnxmltools for fp16 onnx export --- ...nsducer-stateless7-streaming-2022-12-29.sh | 10 + ...speech-2022-12-29-stateless7-streaming.yml | 2 +- .../export.py | 563 +++++++++++++++++- .../onnx_model_wrapper.py | 231 +++++++ .../zipformer.py | 60 +- requirements-ci.txt | 1 + 6 files changed, 843 insertions(+), 24 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index afb0dc05a..bcbc91a44 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd +log "Test exporting to ONNX format" +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --fp16 \ + --onnx 1 + log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml index 6dd93946a..a1f3b4f75 100644 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_12_29_zipformer_streaming: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 5c06cc052..1bc54fa26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -72,25 +72,81 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. 
You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + +(3) Export to ONNX format with pretrained.pt + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(4) Export to ONNX format for triton server + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx-triton 1 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + +Check +https://github.com/k2-fsa/sherpa/tree/master/triton +for how to use the exported models outside of icefall. + """ + import argparse import logging from pathlib import Path +import onnxruntime import sentencepiece as spm import torch import torch.nn as nn +from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states from icefall.checkpoint import ( average_checkpoints, @@ -172,6 +228,42 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, --onnx would export model into the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. 
+ """, + ) + + parser.add_argument( + "--fp16", + action="store_true", + help="whether to export fp16 onnx model, default false", + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +276,391 @@ def get_parser(): return parser +def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + for a, b in zip(xlist, blist): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print("small mismatch detected", error) + else: + return False + return True + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + batch_size = 17 + seq_len = 101 + torch.manual_seed(0) + x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) + x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. 
+ # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] + states = stack_states(initial_states) + + left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks + encoder_attention_dim = encoder_model.encoders[0].attention_dim + + len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 + avg_cache = torch.cat( + states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dim - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders + ] + ] + attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) + + torch.onnx.export( + encoder_model_wrapper, + (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "x", + "x_lens", + "len_cache", + "avg_cache", + "attn_cache", + "cnn_cache", + ], + output_names=[ + "encoder_out", + "encoder_out_lens", + "new_len_cache", + "new_avg_cache", + "new_attn_cache", + "new_cnn_cache", + ], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "len_cache": {0: "N"}, + "avg_cache": {0: "N"}, + "attn_cache": {0: "N"}, + "cnn_cache": {0: "N"}, + "new_len_cache": {0: "N"}, + "new_avg_cache": {0: "N"}, + "new_attn_cache": {0: "N"}, + "new_cnn_cache": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + # Test onnx encoder with torch native encoder + encoder_model.eval() + ( + encoder_out_torch, + encoder_out_lens_torch, + new_states_torch, + ) = encoder_model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + ort_session = onnxruntime.InferenceSession( + str(encoder_filename), providers=["CPUExecutionProvider"] + ) + ort_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + "len_cache": len_cache.numpy(), + "avg_cache": avg_cache.numpy(), + "attn_cache": attn_cache.numpy(), + "cnn_cache": cnn_cache.numpy(), + } + ort_outs = ort_session.run(None, ort_inputs) + + assert test_acc( + [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] + ) + logging.info(f"{encoder_filename} acc test succeeded.") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y,), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + 
verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + and has one output: + - joiner_out: a tensor of shape (N, vocab_size) + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + joiner_model = TritonOnnxJoiner(joiner_model) + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (encoder_out, decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out", "decoder_out"], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +769,87 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + if not params.onnx_triton: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + else: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + if params.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + print("Please install onnxmltools!") + import sys + + sys.exit(1) + + def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + 
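+            # Note: convert_float_to_float16() from onnxmltools casts the float32
+            # tensors of the ONNX graph to float16. The *_fp16.onnx files written
+            # below are not accuracy-checked against their fp32 counterparts here
+            # (unlike the encoder export above), so verifying them is left to the user.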
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_fp16_filename) + + decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_fp16_filename) + + joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_fp16_filename) + + if not params.onnx_triton: + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + encoder_proj_fp16_filename = ( + params.exp_dir / "joiner_encoder_proj_fp16.onnx" + ) + export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + decoder_proj_fp16_filename = ( + params.exp_dir / "joiner_decoder_proj_fp16.onnx" + ) + export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) + + elif params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 100644 index 000000000..f52deecc9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1,231 @@ +from typing import Optional, Tuple + +import torch + + +class OnnxStreamingEncoder(torch.nn.Module): + """This class warps the streaming Zipformer to reduce the number of + state tensors for onnx. + https://github.com/k2-fsa/icefall/pull/831 + """ + + def __init__(self, encoder): + """ + Args: + encoder: A Instance of Zipformer Class + """ + super().__init__() + self.model = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + len_cache: torch.tensor, + avg_cache: torch.tensor, + attn_cache: torch.tensor, + cnn_cache: torch.tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + len_cache: + The cached numbers of past frames. + avg_cache: + The cached average tensors. + attn_cache: + The cached key tensors of the first attention modules. + The cached value tensors of the first attention modules. + The cached value tensors of the second attention modules. + cnn_cache: + The cached left contexts of the first convolution modules. + The cached left contexts of the second convolution modules. 
+ + Returns: + Return a tuple containing 2 tensors: + + """ + num_encoder_layers = [] + encoder_attention_dims = [] + states = [] + for i, encoder in enumerate(self.model.encoders): + num_encoder_layers.append(encoder.num_layers) + encoder_attention_dims.append(encoder.attention_dim) + + len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] + offset = 0 + for num_layer in num_encoder_layers: + states.append(len_cache[offset : offset + num_layer]) + offset += num_layer + + avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] + offset = 0 + for num_layer in num_encoder_layers: + states.append(avg_cache[offset : offset + num_layer]) + offset += num_layer + + attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] + left_context_len = attn_cache.shape[1] + offset = 0 + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[offset : offset + num_layer, : left_context_len // ds] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + encoder_attention_dim = encoder_attention_dims[i] + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + + cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] + offset = 0 + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + + encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( + 0, 1 + ) # [B,15] + new_avg_cache = torch.cat( + states[self.model.num_encoders : 2 * self.model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + assert len(set(encoder_attention_dims)) == 1 + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dims[0] - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * self.model.num_encoders : 5 * self.model.num_encoders + ] + ] + new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + return ( + encoder_out, + encoder_out_lens, + new_len_cache, + new_avg_cache, + new_attn_cache, + new_cnn_cache, + ) + + +class TritonOnnxDecoder(torch.nn.Module): + """This class warps the Decoder in decoder.py + to remove the scalar input "need_pad". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + """ + + def __init__( + self, + decoder: torch.nn.Module, + ): + """ + Args: + decoder: A instance of Decoder + """ + super().__init__() + self.model = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + # False to not pad the input. Should be False during inference. 
+ need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(torch.nn.Module): + """This class warps the Joiner in joiner.py + to remove the scalar input "project_input". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + "project_input" is set to True. + Triton solutions only need export joiner to a single joiner.onnx. + """ + + def __init__( + self, + joiner: torch.nn.Module, + ): + super().__init__() + self.model = joiner + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + # Apply input projections encoder_proj and decoder_proj. + project_input = True + return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index e13629384..1b267c1c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -2084,16 +2084,26 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -2275,16 +2285,26 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. 
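# Tracing note: the gather-based branch below performs the same relative-to-absolute
# re-indexing as .as_strided(), but with explicit row/column indices and torch.gather(),
# because the overlapping-stride .as_strided() trick does not export well through
# torch.jit.trace() to ONNX.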
- pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(kv_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights diff --git a/requirements-ci.txt b/requirements-ci.txt index b8e49899e..50d4e5e3f 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -22,5 +22,6 @@ typeguard==2.13.3 multi_quantization onnx +onnxmltools onnxruntime kaldifst From caf23546edea120f402b03916d3a5647f54a28d8 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 6 Feb 2023 12:17:45 +0800 Subject: [PATCH 2/2] No more T < S after frame_reducer (#875) * No more T < S after frame_reducer * Fix for style check * Adjust the permissions * Add support for inference to frame_reducer * Fix for flake8 check --------- Co-authored-by: yifanyang --- .../__init__.py | 0 .../export_onnx.py | 0 .../frame_reducer.py | 74 +++++++++++++++---- .../lconv.py | 0 .../model.py | 10 ++- .../onnx_pretrained.py | 0 .../train.py | 3 +- 7 files changed, 65 insertions(+), 22 deletions(-) mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index bc3fc57eb..0841f7cf1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, -# Zengwei Yao) +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -18,7 +19,7 @@ # limitations under the License. 
import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn @@ -44,6 +45,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + y_lens: Optional[torch.Tensor] = None, blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -55,6 +57,9 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. + y_lens: + A tensor of shape (batch_size,) containing the number of frames in + `y` before padding. blank_id: The blank id of ctc_output. Returns: @@ -64,15 +69,45 @@ class FrameReducer(nn.Module): A tensor of shape (batch_size,) containing the number of frames in `out` before padding. """ - N, T, C = x.size() padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + if y_lens is not None: + # Limit the maximum number of reduced frames + limit_lens = T - y_lens + max_limit_len = limit_lens.max().int() + fake_limit_indexes = torch.topk( + ctc_output[:, :, blank_id], max_limit_len + ).indices + T = ( + torch.arange(max_limit_len) + .expand_as( + fake_limit_indexes, + ) + .to(device=x.device) + ) + T = torch.remainder(T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, T) + limit_mask = torch.full_like( + non_blank_mask, + False, + device=x.device, + ).scatter_(1, limit_indexes, True) + + non_blank_mask = non_blank_mask | ~limit_mask + out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max() - pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens + pad_lens_list = ( + torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) + - out_lens + ) max_pad_len = pad_lens_list.max() out = F.pad(x, (0, 0, 0, max_pad_len)) @@ -82,26 +117,30 @@ class FrameReducer(nn.Module): out = out[total_valid_mask].reshape(N, -1, C) - return out.to(device=x.device), out_lens.to(device=x.device) + return out, out_lens if __name__ == "__main__": import time - from torch.nn.utils.rnn import pad_sequence test_times = 10000 + device = "cuda:0" frame_reducer = FrameReducer() # non zero case - x = torch.ones(15, 498, 384, dtype=torch.float32) - x_lens = torch.tensor([498] * 15, dtype=torch.int64) - ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32)) - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x = torch.ones(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.log( + torch.randn(15, 498, 500, dtype=torch.float32, device=device), + ) avg_time = 0 for i in range(test_times): + torch.cuda.synchronize(device=x.device) delta_time = time.time() - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) delta_time = time.time() - delta_time avg_time += delta_time print(x_fr.shape) @@ -109,14 +148,17 @@ if __name__ == "__main__": print(avg_time / test_times) # all zero case - x = torch.zeros(15, 498, 384, dtype=torch.float32) - x_lens = torch.tensor([498] * 15, dtype=torch.int64) - ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32) + x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + 
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device) avg_time = 0 for i in range(test_times): + torch.cuda.synchronize(device=x.device) delta_time = time.time() - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) delta_time = time.time() - delta_time avg_time += delta_time print(x_fr.shape) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py old mode 100755 new mode 100644 index 86acc5a10..0582b289f --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -131,6 +131,10 @@ class Transducer(nn.Module): # compute ctc log-probs ctc_output = self.ctc_output(encoder_out) + # y_lens + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + # blank skip blank_id = self.decoder.blank_id @@ -146,16 +150,14 @@ class Transducer(nn.Module): encoder_out, x_lens, ctc_output, + y_lens, blank_id, ) else: encoder_out_fr = encoder_out x_lens_fr = x_lens - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - + # sos_y sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index b282ab9db..ea280e642 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo, @@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --use-fp16 1 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ --full-libri 1 \ - --max-duration 550 + --max-duration 750 """
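For reference, here is a minimal sketch of how the streaming encoder exported by
export_encoder_model_onnx could be driven with onnxruntime. The tensor names match the
ones declared in the export; feature_chunks and the initial caches (len_cache, avg_cache,
attn_cache, cnn_cache) are placeholders that the caller must prepare, e.g. via
get_init_state() + stack_states() followed by the same cat/transpose/pad steps used in
export_encoder_model_onnx:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession(
        "encoder.onnx", providers=["CPUExecutionProvider"]
    )

    # Streaming caches; layouts must match those built in export_encoder_model_onnx.
    states = {
        "len_cache": len_cache,
        "avg_cache": avg_cache,
        "attn_cache": attn_cache,
        "cnn_cache": cnn_cache,
    }

    for chunk in feature_chunks:  # each chunk: float32 array of shape (N, T, 80)
        inputs = {
            "x": chunk,
            "x_lens": np.full((chunk.shape[0],), chunk.shape[1], dtype=np.int64),
            **states,
        }
        # Outputs follow the declared order: encoder_out, encoder_out_lens,
        # new_len_cache, new_avg_cache, new_attn_cache, new_cnn_cache.
        (
            encoder_out,
            encoder_out_lens,
            states["len_cache"],
            states["avg_cache"],
            states["attn_cache"],
            states["cnn_cache"],
        ) = session.run(None, inputs)

The exported decoder.onnx and joiner.onnx are invoked analogously during search; see
./onnx_pretrained.py and https://github.com/k2-fsa/sherpa-onnx for complete decoding
examples.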