Mirror of https://github.com/k2-fsa/icefall.git

Merge branch 'k2-fsa:master' into repeat-k
commit 51812b2ddc
@@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd

log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --fp16 \
  --onnx 1

log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir $repo/exp \
@@ -39,7 +39,7 @@ concurrency:

jobs:
  run_librispeech_2022_12_29_zipformer_streaming:
    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
0    egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py    Executable file → Normal file
0    egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py    Normal file → Executable file
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
#                                        Zengwei Yao)
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
#                                        Zengwei Yao,
#                                        Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -18,7 +19,7 @@
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
        x: torch.Tensor,
        x_lens: torch.Tensor,
        ctc_output: torch.Tensor,
        y_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
@@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
            `x` before padding.
          ctc_output:
            The CTC output with shape [N, T, vocab_size].
          y_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `y` before padding.
          blank_id:
            The blank id of ctc_output.
        Returns:
@@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
          A tensor of shape (batch_size,) containing the number of frames in
          `out` before padding.
        """

        N, T, C = x.size()

        padding_mask = make_pad_mask(x_lens)
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)

        if y_lens is not None:
            # Limit the maximum number of reduced frames
            limit_lens = T - y_lens
            max_limit_len = limit_lens.max().int()
            fake_limit_indexes = torch.topk(
                ctc_output[:, :, blank_id], max_limit_len
            ).indices
            T = (
                torch.arange(max_limit_len)
                .expand_as(
                    fake_limit_indexes,
                )
                .to(device=x.device)
            )
            T = torch.remainder(T, limit_lens.unsqueeze(1))
            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
            limit_mask = torch.full_like(
                non_blank_mask,
                False,
                device=x.device,
            ).scatter_(1, limit_indexes, True)

            non_blank_mask = non_blank_mask | ~limit_mask

        out_lens = non_blank_mask.sum(dim=1)
        max_len = out_lens.max()
        pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
        pad_lens_list = (
            torch.full_like(
                out_lens,
                max_len.item(),
                device=x.device,
            )
            - out_lens
        )
        max_pad_len = pad_lens_list.max()

        out = F.pad(x, (0, 0, 0, max_pad_len))
@@ -82,26 +117,30 @@ class FrameReducer(nn.Module):

        out = out[total_valid_mask].reshape(N, -1, C)

        return out.to(device=x.device), out_lens.to(device=x.device)
        return out, out_lens


if __name__ == "__main__":
    import time
    from torch.nn.utils.rnn import pad_sequence

    test_times = 10000
    device = "cuda:0"
    frame_reducer = FrameReducer()

    # non zero case
    x = torch.ones(15, 498, 384, dtype=torch.float32)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
    ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
    x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.log(
        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
    )

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
        print(x_fr.shape)
@@ -109,14 +148,17 @@ if __name__ == "__main__":
    print(avg_time / test_times)

    # all zero case
    x = torch.zeros(15, 498, 384, dtype=torch.float32)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
        print(x_fr.shape)
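For readers skimming this hunk, the core of the blank-skip idea implemented by FrameReducer fits in a few lines. The sketch below is illustrative only (made-up shapes, not part of the patch): frames whose CTC blank log-probability reaches log(0.9) are treated as blank and dropped, and the per-utterance count of surviving frames plays the role of out_lens.

import torch

torch.manual_seed(0)
N, T, V = 2, 6, 5  # hypothetical batch, frame and vocab sizes
ctc_output = torch.log_softmax(torch.randn(N, T, V), dim=-1)
blank_id = 0

# Keep a frame only if it is not confidently blank, mirroring non_blank_mask above.
non_blank_mask = ctc_output[:, :, blank_id] < torch.log(torch.tensor(0.9))
out_lens = non_blank_mask.sum(dim=1)  # analogous to out_lens in FrameReducer
print(non_blank_mask)
print(out_lens)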
0     egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py    Executable file → Normal file
10    egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py    Executable file → Normal file
@@ -131,6 +131,10 @@ class Transducer(nn.Module):
        # compute ctc log-probs
        ctc_output = self.ctc_output(encoder_out)

        # y_lens
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        # blank skip
        blank_id = self.decoder.blank_id

@@ -146,16 +150,14 @@ class Transducer(nn.Module):
                encoder_out,
                x_lens,
                ctc_output,
                y_lens,
                blank_id,
            )
        else:
            encoder_out_fr = encoder_out
            x_lens_fr = x_lens

        # Now for the decoder, i.e., the prediction network
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        # sos_y
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
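The two added row_splits lines compute per-utterance label lengths from the ragged `y` (a k2 ragged tensor in this recipe). The differencing is easiest to see on a plain tensor; the values below are made up and only illustrate the idiom.

import torch

row_splits = torch.tensor([0, 2, 5, 9])  # hypothetical boundaries of 3 utterances in y
y_lens = row_splits[1:] - row_splits[:-1]
print(y_lens)  # tensor([2, 3, 4])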
0     egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py    Normal file → Executable file
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Wei Kang,
#                                            Mingshuang Luo,
@@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --use-fp16 1 \
  --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
  --full-libri 1 \
  --max-duration 550
  --max-duration 750
"""
@@ -72,25 +72,81 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at

https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29

with the following commands:

sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp

(3) Export to ONNX format with pretrained.pt

cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --use-averaged-model False \
  --epoch 999 \
  --avg 1 \
  --fp16 \
  --onnx 1

It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.

  - encoder.onnx
  - decoder.onnx
  - joiner.onnx
  - joiner_encoder_proj.onnx
  - joiner_decoder_proj.onnx

Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.

(4) Export to ONNX format for triton server

cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --use-averaged-model False \
  --epoch 999 \
  --avg 1 \
  --fp16 \
  --onnx-triton 1 \
  --onnx 1

It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.

  - encoder.onnx
  - decoder.onnx
  - joiner.onnx

Check
https://github.com/k2-fsa/sherpa/tree/master/triton
for how to use the exported models outside of icefall.

"""
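As a quick, non-authoritative sanity check of the files listed in the usage above, the exported streaming encoder can be opened with onnxruntime and its input/output names inspected; the names match the ones passed to torch.onnx.export further down, and the file path below assumes you run this from the exp_dir mentioned above.

import onnxruntime

sess = onnxruntime.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input :", inp.name, inp.shape)
for out in sess.get_outputs():
    print("output:", out.name, out.shape)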
import argparse
import logging
from pathlib import Path

import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states

from icefall.checkpoint import (
    average_checkpoints,
@@ -172,6 +228,42 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--onnx",
        type=str2bool,
        default=False,
        help="""If True, --jit is ignored and it exports the model
        to onnx format. It will generate the following files:

            - encoder.onnx
            - decoder.onnx
            - joiner.onnx
            - joiner_encoder_proj.onnx
            - joiner_decoder_proj.onnx

        Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--onnx-triton",
        type=str2bool,
        default=False,
        help="""If True, --onnx would export model into the following files:

            - encoder.onnx
            - decoder.onnx
            - joiner.onnx

        These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton.
        """,
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether to export fp16 onnx model, default false",
    )

    parser.add_argument(
        "--context-size",
        type=int,
@@ -184,6 +276,391 @@ def get_parser():
    return parser


def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
    for a, b in zip(xlist, blist):
        try:
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        except AssertionError as error:
            if tolerate_small_mismatch:
                print("small mismatch detected", error)
            else:
                return False
    return True


def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - encoder_out, a tensor of shape (N, T, C)
        - encoder_out_lens, a tensor of shape (N,)

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    batch_size = 17
    seq_len = 101
    torch.manual_seed(0)
    x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
    x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)

    # encoder_model = torch.jit.script(encoder_model)
    # It throws the following error for the above statement
    #
    # RuntimeError: Exporting the operator __is_ to ONNX opset version
    # 11 is not supported. Please feel free to request support or
    # submit a pull request on PyTorch GitHub.
    #
    # I cannot find which statement causes the above error.
    # torch.onnx.export() will use torch.jit.trace() internally, which
    # works well for the current reworked model
    initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
    states = stack_states(initial_states)

    left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
    encoder_attention_dim = encoder_model.encoders[0].attention_dim

    len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1)  # B,15
    avg_cache = torch.cat(
        states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
    ).transpose(
        0, 1
    )  # [B,15,384]
    cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
        0, 1
    )  # [B,2*15,384,cnn_kernel-1]
    pad_tensors = [
        torch.nn.functional.pad(
            tensor,
            (
                0,
                encoder_attention_dim - tensor.shape[-1],
                0,
                0,
                0,
                left_context_len - tensor.shape[1],
                0,
                0,
            ),
        )
        for tensor in states[
            2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
        ]
    ]
    attn_cache = torch.cat(pad_tensors).transpose(0, 2)  # [B,64,15*3,192]

    encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)

    torch.onnx.export(
        encoder_model_wrapper,
        (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_lens",
            "len_cache",
            "avg_cache",
            "attn_cache",
            "cnn_cache",
        ],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "new_len_cache",
            "new_avg_cache",
            "new_attn_cache",
            "new_cnn_cache",
        ],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
            "len_cache": {0: "N"},
            "avg_cache": {0: "N"},
            "attn_cache": {0: "N"},
            "cnn_cache": {0: "N"},
            "new_len_cache": {0: "N"},
            "new_avg_cache": {0: "N"},
            "new_attn_cache": {0: "N"},
            "new_cnn_cache": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_filename}")

    # Test onnx encoder with torch native encoder
    encoder_model.eval()
    (
        encoder_out_torch,
        encoder_out_lens_torch,
        new_states_torch,
    ) = encoder_model.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=states,
    )
    ort_session = onnxruntime.InferenceSession(
        str(encoder_filename), providers=["CPUExecutionProvider"]
    )
    ort_inputs = {
        "x": x.numpy(),
        "x_lens": x_lens.numpy(),
        "len_cache": len_cache.numpy(),
        "avg_cache": avg_cache.numpy(),
        "attn_cache": attn_cache.numpy(),
        "cnn_cache": cnn_cache.numpy(),
    }
    ort_outs = ort_session.run(None, ort_inputs)

    assert test_acc(
        [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
    )
    logging.info(f"{encoder_filename} acc test succeeded.")


def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = False  # Always False, so we can use torch.jit.trace() here
    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
    # in this case
    torch.onnx.export(
        decoder_model,
        (y, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_decoder_model_onnx_triton(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)

    decoder_model = TritonOnnxDecoder(decoder_model)

    torch.onnx.export(
        decoder_model,
        (y,),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported joiner model has two inputs:

        - projected_encoder_out: a tensor of shape (N, joiner_dim)
        - projected_decoder_out: a tensor of shape (N, joiner_dim)

    and produces one output:

        - logit: a tensor of shape (N, vocab_size)

    The exported encoder_proj model has one input:

        - encoder_out: a tensor of shape (N, encoder_out_dim)

    and produces one output:

        - projected_encoder_out: a tensor of shape (N, joiner_dim)

    The exported decoder_proj model has one input:

        - decoder_out: a tensor of shape (N, decoder_out_dim)

    and produces one output:

        - projected_decoder_out: a tensor of shape (N, joiner_dim)
    """
    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")

    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    joiner_dim = joiner_model.decoder_proj.weight.shape[0]

    projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
    projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)

    project_input = False
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (projected_encoder_out, projected_decoder_out, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "encoder_out",
            "decoder_out",
            "project_input",
        ],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")

    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    torch.onnx.export(
        joiner_model.encoder_proj,
        encoder_out,
        encoder_proj_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out"],
        output_names=["projected_encoder_out"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "projected_encoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_proj_filename}")

    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
    torch.onnx.export(
        joiner_model.decoder_proj,
        decoder_out,
        decoder_proj_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["decoder_out"],
        output_names=["projected_decoder_out"],
        dynamic_axes={
            "decoder_out": {0: "N"},
            "projected_decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_proj_filename}")
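A hypothetical sketch of how the three files produced by export_joiner_model_onnx fit together at inference time. Input/output names follow the export calls above; the 4-D reshape mirrors the (1, 1, 1, joiner_dim) shapes used for tracing, and the feature dimensions are read from the loaded models rather than assumed.

import numpy as np
import onnxruntime

providers = ["CPUExecutionProvider"]
enc_proj = onnxruntime.InferenceSession("joiner_encoder_proj.onnx", providers=providers)
dec_proj = onnxruntime.InferenceSession("joiner_decoder_proj.onnx", providers=providers)
joiner = onnxruntime.InferenceSession("joiner.onnx", providers=providers)

encoder_out = np.random.rand(1, enc_proj.get_inputs()[0].shape[1]).astype(np.float32)
decoder_out = np.random.rand(1, dec_proj.get_inputs()[0].shape[1]).astype(np.float32)

p_enc = enc_proj.run(["projected_encoder_out"], {"encoder_out": encoder_out})[0]
p_dec = dec_proj.run(["projected_decoder_out"], {"decoder_out": decoder_out})[0]
logit = joiner.run(
    ["logit"],
    {"encoder_out": p_enc.reshape(1, 1, 1, -1), "decoder_out": p_dec.reshape(1, 1, 1, -1)},
)[0]
print(logit.shape)  # (..., vocab_size)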
def export_joiner_model_onnx_triton(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported model has two inputs:
        - encoder_out: a tensor of shape (N, encoder_out_dim)
        - decoder_out: a tensor of shape (N, decoder_out_dim)
    and has one output:
        - joiner_out: a tensor of shape (N, vocab_size)
    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    joiner_model = TritonOnnxJoiner(joiner_model)
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (encoder_out, decoder_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out"],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()
@@ -292,7 +769,87 @@ def main():
    model.to("cpu")
    model.eval()

    if params.jit is True:
    if params.onnx:
        convert_scaled_to_non_scaled(model, inplace=True)
        opset_version = 13
        logging.info("Exporting to onnx format")
        encoder_filename = params.exp_dir / "encoder.onnx"
        export_encoder_model_onnx(
            model.encoder,
            encoder_filename,
            opset_version=opset_version,
        )
        if not params.onnx_triton:
            decoder_filename = params.exp_dir / "decoder.onnx"
            export_decoder_model_onnx(
                model.decoder,
                decoder_filename,
                opset_version=opset_version,
            )

            joiner_filename = params.exp_dir / "joiner.onnx"
            export_joiner_model_onnx(
                model.joiner,
                joiner_filename,
                opset_version=opset_version,
            )
        else:
            decoder_filename = params.exp_dir / "decoder.onnx"
            export_decoder_model_onnx_triton(
                model.decoder,
                decoder_filename,
                opset_version=opset_version,
            )

            joiner_filename = params.exp_dir / "joiner.onnx"
            export_joiner_model_onnx_triton(
                model.joiner,
                joiner_filename,
                opset_version=opset_version,
            )

        if params.fp16:
            try:
                import onnxmltools
                from onnxmltools.utils.float16_converter import convert_float_to_float16
            except ImportError:
                print("Please install onnxmltools!")
                import sys

                sys.exit(1)

            def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
                onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
                onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
                onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)

            encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
            export_onnx_fp16(encoder_filename, encoder_fp16_filename)

            decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
            export_onnx_fp16(decoder_filename, decoder_fp16_filename)

            joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
            export_onnx_fp16(joiner_filename, joiner_fp16_filename)

            if not params.onnx_triton:
                encoder_proj_filename = str(joiner_filename).replace(
                    ".onnx", "_encoder_proj.onnx"
                )
                encoder_proj_fp16_filename = (
                    params.exp_dir / "joiner_encoder_proj_fp16.onnx"
                )
                export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)

                decoder_proj_filename = str(joiner_filename).replace(
                    ".onnx", "_decoder_proj.onnx"
                )
                decoder_proj_fp16_filename = (
                    params.exp_dir / "joiner_decoder_proj_fp16.onnx"
                )
                export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)

    elif params.jit:
        convert_scaled_to_non_scaled(model, inplace=True)
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
@@ -0,0 +1,231 @@
from typing import Optional, Tuple

import torch


class OnnxStreamingEncoder(torch.nn.Module):
    """This class wraps the streaming Zipformer to reduce the number of
    state tensors for onnx.
    https://github.com/k2-fsa/icefall/pull/831
    """

    def __init__(self, encoder):
        """
        Args:
          encoder: An instance of the Zipformer class
        """
        super().__init__()
        self.model = encoder

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        len_cache: torch.tensor,
        avg_cache: torch.tensor,
        attn_cache: torch.tensor,
        cnn_cache: torch.tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          len_cache:
            The cached numbers of past frames.
          avg_cache:
            The cached average tensors.
          attn_cache:
            The cached key tensors of the first attention modules.
            The cached value tensors of the first attention modules.
            The cached value tensors of the second attention modules.
          cnn_cache:
            The cached left contexts of the first convolution modules.
            The cached left contexts of the second convolution modules.

        Returns:
          Return a tuple containing 6 tensors: encoder_out, encoder_out_lens,
          and the four updated caches.
        """
        num_encoder_layers = []
        encoder_attention_dims = []
        states = []
        for i, encoder in enumerate(self.model.encoders):
            num_encoder_layers.append(encoder.num_layers)
            encoder_attention_dims.append(encoder.attention_dim)

        len_cache = len_cache.transpose(0, 1)  # sum(num_encoder_layers)==15, [15, B]
        offset = 0
        for num_layer in num_encoder_layers:
            states.append(len_cache[offset : offset + num_layer])
            offset += num_layer

        avg_cache = avg_cache.transpose(0, 1)  # [15, B, 384]
        offset = 0
        for num_layer in num_encoder_layers:
            states.append(avg_cache[offset : offset + num_layer])
            offset += num_layer

        attn_cache = attn_cache.transpose(0, 2)  # [15*3, 64, B, 192]
        left_context_len = attn_cache.shape[1]
        offset = 0
        for i, num_layer in enumerate(num_encoder_layers):
            ds = self.model.zipformer_downsampling_factors[i]
            states.append(
                attn_cache[offset : offset + num_layer, : left_context_len // ds]
            )
            offset += num_layer
        for i, num_layer in enumerate(num_encoder_layers):
            encoder_attention_dim = encoder_attention_dims[i]
            ds = self.model.zipformer_downsampling_factors[i]
            states.append(
                attn_cache[
                    offset : offset + num_layer,
                    : left_context_len // ds,
                    :,
                    : encoder_attention_dim // 2,
                ]
            )
            offset += num_layer
        for i, num_layer in enumerate(num_encoder_layers):
            ds = self.model.zipformer_downsampling_factors[i]
            states.append(
                attn_cache[
                    offset : offset + num_layer,
                    : left_context_len // ds,
                    :,
                    : encoder_attention_dim // 2,
                ]
            )
            offset += num_layer

        cnn_cache = cnn_cache.transpose(0, 1)  # [30, B, 384, cnn_kernel-1]
        offset = 0
        for num_layer in num_encoder_layers:
            states.append(cnn_cache[offset : offset + num_layer])
            offset += num_layer
        for num_layer in num_encoder_layers:
            states.append(cnn_cache[offset : offset + num_layer])
            offset += num_layer

        encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=states,
        )

        new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
            0, 1
        )  # [B,15]
        new_avg_cache = torch.cat(
            states[self.model.num_encoders : 2 * self.model.num_encoders]
        ).transpose(
            0, 1
        )  # [B,15,384]
        new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
            0, 1
        )  # [B,2*15,384,cnn_kernel-1]
        assert len(set(encoder_attention_dims)) == 1
        pad_tensors = [
            torch.nn.functional.pad(
                tensor,
                (
                    0,
                    encoder_attention_dims[0] - tensor.shape[-1],
                    0,
                    0,
                    0,
                    left_context_len - tensor.shape[1],
                    0,
                    0,
                ),
            )
            for tensor in states[
                2 * self.model.num_encoders : 5 * self.model.num_encoders
            ]
        ]
        new_attn_cache = torch.cat(pad_tensors).transpose(0, 2)  # [B,64,15*3,192]

        return (
            encoder_out,
            encoder_out_lens,
            new_len_cache,
            new_avg_cache,
            new_attn_cache,
            new_cnn_cache,
        )
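To make the flattened-cache contract of OnnxStreamingEncoder concrete, here is a hypothetical driving loop with onnxruntime. It is a sketch only: real cache values come from get_init_state()/stack_states() as in export.py above (zeros are placeholders), the chunk length must match what the model was exported for, and cache dtypes are read from the session metadata rather than assumed.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])

def placeholder(inp, batch_size=1):
    # Substitute the dynamic batch axis; the remaining dims are static in the export.
    shape = [batch_size if isinstance(d, str) else d for d in inp.shape]
    dtype = np.int64 if "int64" in inp.type else np.float32
    return np.zeros(shape, dtype=dtype)

cache_names = ("len_cache", "avg_cache", "attn_cache", "cnn_cache")
caches = {i.name: placeholder(i) for i in sess.get_inputs() if i.name in cache_names}

chunk_len = 32  # arbitrary; a real model expects its export-time chunk size
feature_chunks = [np.random.rand(1, chunk_len, 80).astype(np.float32) for _ in range(3)]

for chunk in feature_chunks:
    x_lens = np.array([chunk.shape[1]], dtype=np.int64)
    outs = sess.run(None, {"x": chunk, "x_lens": x_lens, **caches})
    encoder_out, encoder_out_lens = outs[0], outs[1]
    caches = dict(zip(cache_names, outs[2:]))  # recycle the new_* caches for the next chunk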
class TritonOnnxDecoder(torch.nn.Module):
    """This class wraps the Decoder in decoder.py
    to remove the scalar input "need_pad".
    Triton currently doesn't support scalar input.
    https://github.com/triton-inference-server/server/issues/2333
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
    ):
        """
        Args:
          decoder: An instance of Decoder
        """
        super().__init__()
        self.model = decoder

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
        # False to not pad the input. Should be False during inference.
        need_pad = False
        return self.model(y, need_pad)


class TritonOnnxJoiner(torch.nn.Module):
    """This class wraps the Joiner in joiner.py
    to remove the scalar input "project_input".
    Triton currently doesn't support scalar input.
    https://github.com/triton-inference-server/server/issues/2333
    "project_input" is set to True.
    Triton deployments only need to export the joiner to a single joiner.onnx.
    """

    def __init__(
        self,
        joiner: torch.nn.Module,
    ):
        super().__init__()
        self.model = joiner

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
        Returns:
          Return a tensor of shape (N, T, s_range, C).
        """
        # Apply input projections encoder_proj and decoder_proj.
        project_input = True
        return self.model(encoder_out, decoder_out, project_input)
@@ -2084,16 +2084,26 @@ class RelPositionMultiheadAttention(nn.Module):
        # the following .as_strided() expression converts the last axis of pos_weights from relative
        # to absolute position. I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        pos_weights = pos_weights.as_strided(
            (bsz, num_heads, seq_len, seq_len),
            (
                pos_weights.stride(0),
                pos_weights.stride(1),
                pos_weights.stride(2) - pos_weights.stride(3),
                pos_weights.stride(3),
            ),
            storage_offset=pos_weights.stride(3) * (seq_len - 1),
        )
        if torch.jit.is_tracing():
            (batch_size, num_heads, time1, n) = pos_weights.shape
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(seq_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_weights = pos_weights.reshape(-1, n)
            pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
            pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
        else:
            pos_weights = pos_weights.as_strided(
                (bsz, num_heads, seq_len, seq_len),
                (
                    pos_weights.stride(0),
                    pos_weights.stride(1),
                    pos_weights.stride(2) - pos_weights.stride(3),
                    pos_weights.stride(3),
                ),
                storage_offset=pos_weights.stride(3) * (seq_len - 1),
            )

        # caution: they are really scores at this point.
        attn_output_weights = torch.matmul(q, k) + pos_weights
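The tracing branch above replaces the as_strided() trick with an explicit gather so the relative-to-absolute conversion can be traced for ONNX export. A small numeric check (not part of the patch; arbitrary shapes, with seq_len equal to time1) shows the two indexings select the same elements:

import torch

bsz, num_heads, seq_len = 2, 3, 5
n = 2 * seq_len - 1
pos_weights = torch.randn(bsz, num_heads, seq_len, n)

strided = pos_weights.as_strided(
    (bsz, num_heads, seq_len, seq_len),
    (
        pos_weights.stride(0),
        pos_weights.stride(1),
        pos_weights.stride(2) - pos_weights.stride(3),
        pos_weights.stride(3),
    ),
    storage_offset=pos_weights.stride(3) * (seq_len - 1),
)

rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(bsz * num_heads).unsqueeze(-1)
indexes = rows + cols
gathered = torch.gather(pos_weights.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(bsz, num_heads, seq_len, seq_len)

assert torch.allclose(strided, gathered)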
@@ -2275,16 +2285,26 @@ class RelPositionMultiheadAttention(nn.Module):
        # the following .as_strided() expression converts the last axis of pos_weights from relative
        # to absolute position. I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        pos_weights = pos_weights.as_strided(
            (bsz, num_heads, seq_len, kv_len),
            (
                pos_weights.stride(0),
                pos_weights.stride(1),
                pos_weights.stride(2) - pos_weights.stride(3),
                pos_weights.stride(3),
            ),
            storage_offset=pos_weights.stride(3) * (seq_len - 1),
        )
        if torch.jit.is_tracing():
            (batch_size, num_heads, time1, n) = pos_weights.shape
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(kv_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_weights = pos_weights.reshape(-1, n)
            pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
            pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len)
        else:
            pos_weights = pos_weights.as_strided(
                (bsz, num_heads, seq_len, kv_len),
                (
                    pos_weights.stride(0),
                    pos_weights.stride(1),
                    pos_weights.stride(2) - pos_weights.stride(3),
                    pos_weights.stride(3),
                ),
                storage_offset=pos_weights.stride(3) * (seq_len - 1),
            )

        # caution: they are really scores at this point.
        attn_output_weights = torch.matmul(q, k) + pos_weights
@@ -22,5 +22,6 @@ typeguard==2.13.3
multi_quantization

onnx
onnxmltools
onnxruntime
kaldifst