From 52f3a747bec5d06ad0eba6a77f5abaf21110e5ec Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Feb 2023 12:12:26 +0800 Subject: [PATCH] Refactor onnx export for streaming zipformer (#879) --- ...nsducer-stateless7-streaming-2022-12-29.sh | 3 +- .github/scripts/test-onnx-export.sh | 70 ++ .github/workflows/test-onnx-export.yml | 75 ++ .../export-onnx.py | 639 ++++++++++++++++++ .../onnx_check.py | 267 ++++++++ .../onnx_model_wrapper.py | 2 +- .../onnx_pretrained.py | 510 ++++++++++++++ .../zipformer.py | 11 +- 8 files changed, 1572 insertions(+), 5 deletions(-) create mode 100755 .github/scripts/test-onnx-export.sh create mode 100644 .github/workflows/test-onnx-export.yml create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index bcbc91a44..2611c8efc 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -22,13 +22,14 @@ tree $repo/ soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav -pushd $repo/exp +pushd $repo git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/cpu_jit.pt" git lfs pull --include "exp/pretrained.pt" git lfs pull --include "exp/encoder_jit_trace.pt" git lfs pull --include "exp/decoder_jit_trace.pt" git lfs pull --include "exp/joiner_jit_trace.pt" +cd exp ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh new file mode 100755 index 000000000..20aa02950 --- /dev/null +++ b/.github/scripts/test-onnx-export.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +log "==========================================================================" +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml new file mode 100644 index 000000000..c7729dedb --- /dev/null +++ b/.github/workflows/test-onnx-export.yml @@ -0,0 +1,75 @@ +name: test-onnx-export + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: test_onnx_export-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_onnx_export: + if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test ONNX export + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/test-onnx-export.sh diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 100755 index 000000000..a72472495 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1,639 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "streaming_zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "pad_length": str(pad_length), # 7 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "encoder_out": {0: "N", 1: "T"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 100755 index 000000000..72ad59a55 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel +from zipformer import stack_states + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] + torch_states = stack_states(torch_states) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py index f52deecc9..71a418742 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -12,7 +12,7 @@ class OnnxStreamingEncoder(torch.nn.Module): def __init__(self, encoder): """ Args: - encoder: A Instance of Zipformer Class + encoder: An instance of Zipformer Class """ super().__init__() self.model = encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 100755 index 000000000..265684d18 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + pad_length = int(encoder_meta["pad_length"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + attention_dims = encoder_meta["attention_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + attention_dims = to_int_list(attention_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"attention_dims: {attention_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + + num_encoders = len(num_encoder_layers) + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + N = batch_size + + for i in range(num_encoders): + cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) + cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) + cached_key.append( + torch.zeros( + num_encoder_layers[i], left_context_len[i], N, attention_dims[i] + ) + ) + cached_val.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_val2.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_conv1.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + cached_conv2.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + + self.cached_len = cached_len + self.cached_avg = cached_avg + self.cached_key = cached_key + self.cached_val = cached_val + self.cached_val2 = cached_val2 + self.cached_conv1 = cached_conv1 + self.cached_conv2 = cached_conv2 + + self.num_encoders = num_encoders + + self.segment = decode_chunk_len + pad_length + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_states_input(states: List[torch.Tensor], name: str): + for i, s in enumerate(states): + if isinstance(s, torch.Tensor): + encoder_input[f"{name}_{i}"] = s.numpy() + else: + encoder_input[f"{name}_{i}"] = s + + encoder_output.append(f"new_{name}_{i}") + + build_states_input(self.cached_len, "cached_len") + build_states_input(self.cached_avg, "cached_avg") + build_states_input(self.cached_key, "cached_key") + build_states_input(self.cached_val, "cached_val") + build_states_input(self.cached_val2, "cached_val2") + build_states_input(self.cached_conv1, "cached_conv1") + build_states_input(self.cached_conv2, "cached_conv2") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + num_encoders = self.num_encoders + + self.cached_len = states[num_encoders * 0 : num_encoders * 1] + self.cached_avg = states[num_encoders * 1 : num_encoders * 2] + self.cached_key = states[num_encoders * 2 : num_encoders * 3] + self.cached_val = states[num_encoders * 3 : num_encoders * 4] + self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] + self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] + self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + return torch.from_numpy(out) + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 1b267c1c5..f7e52a9e6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -270,7 +270,7 @@ class Zipformer(EncoderInterface): dim_feedforward (int, int): feedforward dimension in 2 encoder stacks num_encoder_layers (int): number of encoder layers dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module + cnn_module_kernels (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. warmup_batches (float): number of batches to warm up over """ @@ -311,6 +311,8 @@ class Zipformer(EncoderInterface): # Used in decoding self.decode_chunk_size = decode_chunk_size + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + # will be written to, see set_batch_count() self.batch_count = 0 self.warmup_end = warmup_batches @@ -330,7 +332,10 @@ class Zipformer(EncoderInterface): # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] + self.num_encoder_layers = num_encoder_layers self.num_encoders = len(encoder_dims) + self.attention_dims = attention_dim + self.cnn_module_kernels = cnn_module_kernels for i in range(self.num_encoders): encoder_layer = ZipformerEncoderLayer( encoder_dims[i], @@ -382,7 +387,7 @@ class Zipformer(EncoderInterface): def _init_skip_modules(self): """ - If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, we combine the outputs of layers 1 and 5. @@ -695,7 +700,7 @@ class Zipformer(EncoderInterface): num_layers = encoder.num_layers ds = self.zipformer_downsampling_factors[i] - len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device) + len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) cached_len.append(len_avg) avg = torch.zeros(num_layers, 1, encoder.d_model, device=device)