mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Refactor onnx export for streaming zipformer (#879)
This commit is contained in:
parent
5a05b95730
commit
52f3a747be
@ -22,13 +22,14 @@ tree $repo/
|
||||
soxi $repo/test_wavs/*.wav
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/cpu_jit.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
git lfs pull --include "exp/encoder_jit_trace.pt"
|
||||
git lfs pull --include "exp/decoder_jit_trace.pt"
|
||||
git lfs pull --include "exp/joiner_jit_trace.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
70
.github/scripts/test-onnx-export.sh
vendored
Executable file
70
.github/scripts/test-onnx-export.sh
vendored
Executable file
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
|
||||
log "=========================================================================="
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Export via torch.jit.trace()"
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
ls -lh $repo/exp
|
||||
|
||||
log "Run onnx_check.py"
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_check.py \
|
||||
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||
|
||||
log "Run onnx_pretrained.py"
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
rm -rf $repo
|
||||
log "--------------------------------------------------------------------------"
|
75
.github/workflows/test-onnx-export.yml
vendored
Normal file
75
.github/workflows/test-onnx-export.yml
vendored
Normal file
@ -0,0 +1,75 @@
|
||||
name: test-onnx-export
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
concurrency:
|
||||
group: test_onnx_export-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test_onnx_export:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Test ONNX export
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/test-onnx-export.sh
|
639
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
Executable file
639
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
Executable file
@ -0,0 +1,639 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files in $repo/exp
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py for how to use the exported models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
|
||||
"""Please see the help information of Zipformer.streaming_forward"""
|
||||
N = x.size(0)
|
||||
T = x.size(1)
|
||||
x_lens = torch.tensor([T] * N, device=x.device)
|
||||
|
||||
output, _, new_states = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
output = self.encoder_proj(output)
|
||||
# Now output is of shape (N, T, joiner_dim)
|
||||
|
||||
return output, new_states
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""
|
||||
Onnx model inputs:
|
||||
- 0: src
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
|
||||
Onnx model outputs:
|
||||
- 0: output, its shape is (N, T, joiner_dim)
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The model to be exported
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
)
|
||||
|
||||
decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
T = decode_chunk_len + pad_length
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"pad_length: {pad_length}")
|
||||
logging.info(f"T: {T}")
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
|
||||
init_state = encoder_model.encoder.get_init_state()
|
||||
|
||||
num_encoders = encoder_model.encoder.num_encoders
|
||||
logging.info(f"num_encoders: {num_encoders}")
|
||||
logging.info(f"len(init_state): {len(init_state)}")
|
||||
|
||||
inputs = {}
|
||||
input_names = ["x"]
|
||||
|
||||
outputs = {}
|
||||
output_names = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, name, N):
|
||||
for i, s in enumerate(tensors):
|
||||
logging.info(f"{name}_{i}.shape: {s.shape}")
|
||||
inputs[f"{name}_{i}"] = {N: "N"}
|
||||
outputs[f"new_{name}_{i}"] = {N: "N"}
|
||||
input_names.append(f"{name}_{i}")
|
||||
output_names.append(f"new_{name}_{i}")
|
||||
|
||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims))
|
||||
attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims))
|
||||
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels))
|
||||
ds = encoder_model.encoder.zipformer_downsampling_factors
|
||||
left_context_len = encoder_model.encoder.left_context_len
|
||||
left_context_len = [left_context_len // k for k in ds]
|
||||
left_context_len = ",".join(map(str, left_context_len))
|
||||
|
||||
meta_data = {
|
||||
"model_type": "streaming_zipformer",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||
"pad_length": str(pad_length), # 7
|
||||
"num_encoder_layers": num_encoder_layers,
|
||||
"encoder_dims": encoder_dims,
|
||||
"attention_dims": attention_dims,
|
||||
"cnn_module_kernels": cnn_module_kernels,
|
||||
"left_context_len": left_context_len,
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
# (num_encoder_layers, 1)
|
||||
cached_len = init_state[num_encoders * 0 : num_encoders * 1]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim)
|
||||
cached_avg = init_state[num_encoders * 1 : num_encoders * 2]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim)
|
||||
cached_key = init_state[num_encoders * 2 : num_encoders * 3]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val = init_state[num_encoders * 3 : num_encoders * 4]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val2 = init_state[num_encoders * 4 : num_encoders * 5]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
build_inputs_outputs(cached_len, "cached_len", 1)
|
||||
build_inputs_outputs(cached_avg, "cached_avg", 1)
|
||||
build_inputs_outputs(cached_key, "cached_key", 2)
|
||||
build_inputs_outputs(cached_val, "cached_val", 2)
|
||||
build_inputs_outputs(cached_val2, "cached_val2", 2)
|
||||
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
|
||||
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
|
||||
|
||||
logging.info(inputs)
|
||||
logging.info(outputs)
|
||||
logging.info(input_names)
|
||||
logging.info(output_names)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, init_state),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
**inputs,
|
||||
**outputs,
|
||||
},
|
||||
)
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
if params.use_averaged_model:
|
||||
suffix += "-with-averaged-model"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
267
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
Executable file
267
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
Executable file
@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script checks that exported ONNX models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model via torch.jit.trace()
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp
|
||||
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
3. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
4. Run this file
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_check.py \
|
||||
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from onnx_pretrained import OnnxModel
|
||||
from zipformer import stack_states
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
torch_encoder_model: torch.jit.ScriptModule,
|
||||
torch_encoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
T = onnx_model.segment
|
||||
C = 80
|
||||
x_lens = torch.tensor([T] * N)
|
||||
torch_states = [torch_encoder_model.get_init_state() for _ in range(N)]
|
||||
torch_states = stack_states(torch_states)
|
||||
|
||||
onnx_model.init_encoder_states(N)
|
||||
|
||||
for i in range(5):
|
||||
logging.info(f"test_encoder: iter {i}")
|
||||
x = torch.rand(N, T, C)
|
||||
torch_encoder_out, _, torch_states = torch_encoder_model(
|
||||
x, x_lens, torch_states
|
||||
)
|
||||
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
|
||||
|
||||
onnx_encoder_out = onnx_model.run_encoder(x)
|
||||
|
||||
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
|
||||
(torch_encoder_out - onnx_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
torch_decoder_model: torch.jit.ScriptModule,
|
||||
torch_decoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
context_size = onnx_model.context_size
|
||||
vocab_size = onnx_model.vocab_size
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_decoder: iter {i}, N={N}")
|
||||
x = torch.randint(
|
||||
low=1,
|
||||
high=vocab_size,
|
||||
size=(N, context_size),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
|
||||
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
|
||||
torch_decoder_out = torch_decoder_out.squeeze(1)
|
||||
|
||||
onnx_decoder_out = onnx_model.run_decoder(x)
|
||||
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
|
||||
(torch_decoder_out - onnx_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
torch_joiner_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_joiner: iter {i}, N={N}")
|
||||
encoder_out = torch.rand(N, encoder_dim)
|
||||
decoder_out = torch.rand(N, decoder_dim)
|
||||
|
||||
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
|
||||
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
|
||||
|
||||
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
|
||||
onnx_joiner_out = onnx_model.run_joiner(
|
||||
projected_encoder_out, projected_decoder_out
|
||||
)
|
||||
|
||||
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
|
||||
(torch_joiner_out - onnx_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
|
||||
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
|
||||
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
|
||||
|
||||
onnx_model = OnnxModel(
|
||||
encoder_model_filename=args.onnx_encoder_filename,
|
||||
decoder_model_filename=args.onnx_decoder_filename,
|
||||
joiner_model_filename=args.onnx_joiner_filename,
|
||||
)
|
||||
|
||||
logging.info("Test encoder")
|
||||
# When exporting the model to onnx, we have already put the encoder_proj
|
||||
# inside the encoder.
|
||||
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test decoder")
|
||||
# When exporting the model to onnx, we have already put the decoder_proj
|
||||
# inside the decoder.
|
||||
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test joiner")
|
||||
test_joiner(torch_joiner_model, onnx_model)
|
||||
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/38342
|
||||
# and https://github.com/pytorch/pytorch/issues/33354
|
||||
#
|
||||
# If we don't do this, the delay increases whenever there is
|
||||
# a new request that changes the actual batch size.
|
||||
# If you use `py-spy dump --pid <server-pid> --native`, you will
|
||||
# see a lot of time is spent in re-compiling the torch script model.
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20230207)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -12,7 +12,7 @@ class OnnxStreamingEncoder(torch.nn.Module):
|
||||
def __init__(self, encoder):
|
||||
"""
|
||||
Args:
|
||||
encoder: A Instance of Zipformer Class
|
||||
encoder: An instance of Zipformer Class
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = encoder
|
||||
|
510
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
Executable file
510
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
Executable file
@ -0,0 +1,510 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files in $repo/exp
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
pad_length = int(encoder_meta["pad_length"])
|
||||
|
||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||
encoder_dims = encoder_meta["encoder_dims"]
|
||||
attention_dims = encoder_meta["attention_dims"]
|
||||
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
|
||||
left_context_len = encoder_meta["left_context_len"]
|
||||
|
||||
def to_int_list(s):
|
||||
return list(map(int, s.split(",")))
|
||||
|
||||
num_encoder_layers = to_int_list(num_encoder_layers)
|
||||
encoder_dims = to_int_list(encoder_dims)
|
||||
attention_dims = to_int_list(attention_dims)
|
||||
cnn_module_kernels = to_int_list(cnn_module_kernels)
|
||||
left_context_len = to_int_list(left_context_len)
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"pad_length: {pad_length}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
logging.info(f"encoder_dims: {encoder_dims}")
|
||||
logging.info(f"attention_dims: {attention_dims}")
|
||||
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
|
||||
logging.info(f"left_context_len: {left_context_len}")
|
||||
|
||||
num_encoders = len(num_encoder_layers)
|
||||
|
||||
cached_len = []
|
||||
cached_avg = []
|
||||
cached_key = []
|
||||
cached_val = []
|
||||
cached_val2 = []
|
||||
cached_conv1 = []
|
||||
cached_conv2 = []
|
||||
|
||||
N = batch_size
|
||||
|
||||
for i in range(num_encoders):
|
||||
cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64))
|
||||
cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i]))
|
||||
cached_key.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], left_context_len[i], N, attention_dims[i]
|
||||
)
|
||||
)
|
||||
cached_val.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_val2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_conv1.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
cached_conv2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
|
||||
self.cached_len = cached_len
|
||||
self.cached_avg = cached_avg
|
||||
self.cached_key = cached_key
|
||||
self.cached_val = cached_val
|
||||
self.cached_val2 = cached_val2
|
||||
self.cached_conv1 = cached_conv1
|
||||
self.cached_conv2 = cached_conv2
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
|
||||
self.segment = decode_chunk_len + pad_length
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
encoder_input = {"x": x.numpy()}
|
||||
encoder_output = ["encoder_out"]
|
||||
|
||||
def build_states_input(states: List[torch.Tensor], name: str):
|
||||
for i, s in enumerate(states):
|
||||
if isinstance(s, torch.Tensor):
|
||||
encoder_input[f"{name}_{i}"] = s.numpy()
|
||||
else:
|
||||
encoder_input[f"{name}_{i}"] = s
|
||||
|
||||
encoder_output.append(f"new_{name}_{i}")
|
||||
|
||||
build_states_input(self.cached_len, "cached_len")
|
||||
build_states_input(self.cached_avg, "cached_avg")
|
||||
build_states_input(self.cached_key, "cached_key")
|
||||
build_states_input(self.cached_val, "cached_val")
|
||||
build_states_input(self.cached_val2, "cached_val2")
|
||||
build_states_input(self.cached_conv1, "cached_conv1")
|
||||
build_states_input(self.cached_conv2, "cached_conv2")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
num_encoders = self.num_encoders
|
||||
|
||||
self.cached_len = states[num_encoders * 0 : num_encoders * 1]
|
||||
self.cached_avg = states[num_encoders * 1 : num_encoders * 2]
|
||||
self.cached_key = states[num_encoders * 2 : num_encoders * 3]
|
||||
self.cached_val = states[num_encoders * 3 : num_encoders * 4]
|
||||
self.cached_val2 = states[num_encoders * 4 : num_encoders * 5]
|
||||
self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6]
|
||||
self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
return torch.from_numpy(out)
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -270,7 +270,7 @@ class Zipformer(EncoderInterface):
|
||||
dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
cnn_module_kernels (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
warmup_batches (float): number of batches to warm up over
|
||||
"""
|
||||
@ -311,6 +311,8 @@ class Zipformer(EncoderInterface):
|
||||
# Used in decoding
|
||||
self.decode_chunk_size = decode_chunk_size
|
||||
|
||||
self.left_context_len = self.decode_chunk_size * self.num_left_chunks
|
||||
|
||||
# will be written to, see set_batch_count()
|
||||
self.batch_count = 0
|
||||
self.warmup_end = warmup_batches
|
||||
@ -330,7 +332,10 @@ class Zipformer(EncoderInterface):
|
||||
# each one will be ZipformerEncoder or DownsampledZipformerEncoder
|
||||
encoders = []
|
||||
|
||||
self.num_encoder_layers = num_encoder_layers
|
||||
self.num_encoders = len(encoder_dims)
|
||||
self.attention_dims = attention_dim
|
||||
self.cnn_module_kernels = cnn_module_kernels
|
||||
for i in range(self.num_encoders):
|
||||
encoder_layer = ZipformerEncoderLayer(
|
||||
encoder_dims[i],
|
||||
@ -382,7 +387,7 @@ class Zipformer(EncoderInterface):
|
||||
|
||||
def _init_skip_modules(self):
|
||||
"""
|
||||
If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
|
||||
If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
|
||||
indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
|
||||
layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
|
||||
we combine the outputs of layers 1 and 5.
|
||||
@ -695,7 +700,7 @@ class Zipformer(EncoderInterface):
|
||||
num_layers = encoder.num_layers
|
||||
ds = self.zipformer_downsampling_factors[i]
|
||||
|
||||
len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device)
|
||||
len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device)
|
||||
cached_len.append(len_avg)
|
||||
|
||||
avg = torch.zeros(num_layers, 1, encoder.d_model, device=device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user