From 5f066d3d538e093d72c94e48ad6be4c23da71f84 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 12 Apr 2023 19:04:50 +0800 Subject: [PATCH] support decoding and computing RTF on test sets with onnx models (#995) * support decode and compute RTF on test sets with onnx models * support onnx export and decode in pruned_transducer_stateless --- .../export-onnx.py | 527 ++++++++++++++++++ .../pruned_transducer_stateless/onnx_check.py | 1 + .../onnx_decode.py | 319 +++++++++++ .../onnx_pretrained.py | 1 + .../onnx_decode.py | 321 +++++++++++ .../onnx_pretrained.py | 3 +- .../export-onnx.py | 30 + .../onnx_decode.py | 326 +++++++++++ .../onnx_decode.py | 319 +++++++++++ .../ASR/transducer_stateless/conformer.py | 83 ++- 10 files changed, 1901 insertions(+), 29 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py new file mode 100755 index 000000000..a3ebe9d8c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer""" + + def __init__(self, encoder: Conformer): + """ + Args: + encoder: + A Conformer encoder. + """ + super().__init__() + self.encoder = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder""" + + def __init__(self, decoder: Decoder): + super().__init__() + self.decoder = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + output = decoder_output.squeeze(1) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, inner_linear: nn.Linear, output_linear: nn.Linear): + super().__init__() + self.inner_linear = inner_linear + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.inner_linear(torch.tanh(logit)) + output = self.output_linear(nn.functional.relu(logit)) + + return output + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.inner_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + encoder = OnnxEncoder(encoder=model.encoder) + + decoder = OnnxDecoder(decoder=model.decoder) + + joiner = OnnxJoiner( + inner_linear=model.joiner.inner_linear, output_linear=model.joiner.output_linear + ) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py new file mode 100755 index 000000000..8134d43f8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. Run this file + +./pruned_transducer_stateless/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py new file mode 100755 index 000000000..3b1c72cf1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless3/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from librispeech import LibriSpeech + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 3947aa102..e10915086 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -71,7 +71,6 @@ from typing import List, Tuple import k2 import kaldifeat -import numpy as np import onnxruntime as ort import torch import torchaudio @@ -139,7 +138,7 @@ class OnnxModel: ): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 self.session_opts = session_opts diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 3d94760dc..e89d94d82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -60,6 +60,7 @@ import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer +from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -568,6 +569,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py new file mode 100755 index 000000000..6f26e34b5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp/ \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless5/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py new file mode 100755 index 000000000..67585ee47 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless7/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 94d0393c2..90b722bde 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -24,7 +24,7 @@ import torch from torch import Tensor, nn from transformer import Transformer -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(Transformer): @@ -154,7 +154,8 @@ class Conformer(Transformer): # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 - assert x.size(0) == lengths.max().item() + if not is_jit_tracing(): + assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) @@ -768,6 +769,14 @@ class RelPositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 + self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) @@ -975,22 +984,34 @@ class RelPositionMultiheadAttention(nn.Module): the key, while time1 is for the query). """ (batch_size, num_heads, time1, n) = x.shape + time2 = time1 + left_context + if not is_jit_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, @@ -1061,13 +1082,16 @@ class RelPositionMultiheadAttention(nn.Module): """ tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + if not is_jit_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" + if not is_jit_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 if torch.equal(query, key) and torch.equal(key, value): @@ -1181,7 +1205,8 @@ class RelPositionMultiheadAttention(nn.Module): q = q.transpose(0, 1) # (batch, time1, head, d_k) pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 + if not is_jit_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) @@ -1212,11 +1237,12 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] + if not is_jit_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] if attn_mask is not None: if attn_mask.dtype == torch.bool: @@ -1265,7 +1291,10 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + + if not is_jit_tracing(): + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) )