diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py deleted file mode 100755 index ea7682994..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,325 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang, -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. - -We use the pre-trained model from -https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_2000/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -2. Run this file - -./zipformer/onnx_decode.py \ - --exp-dir $repo/exp \ - --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ -""" - - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import AsrDataModule -from k2 import SymbolTable -from onnx_pretrained import OnnxModel, greedy_search - -from icefall.utils import setup_logger, store_transcripts, write_error_stats - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - The token table. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. - """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - hyps = [token_ids_to_words(h).split() for h in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - The token table. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." - res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = SymbolTable.from_file(args.tokens) - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = AsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py deleted file mode 100755 index e8a521460..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1,419 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./zipformer/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main()