mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add recipe for fine-tuning Zipformer with adapter (#1512)
This commit is contained in:
parent
d89f4ea149
commit
7e2b561bbf
@ -35,6 +35,7 @@ The following table lists the differences among them.
|
|||||||
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
|
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
|
||||||
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
||||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
||||||
|
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
|
||||||
|
|
||||||
The decoder in `transducer_stateless` is modified from the paper
|
The decoder in `transducer_stateless` is modified from the paper
|
||||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||||
|
1
egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../tdnn_lstm_ctc/asr_datamodule.py
|
1
egs/librispeech/ASR/zipformer_adapter/beam_search.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/beam_search.py
|
1062
egs/librispeech/ASR/zipformer_adapter/decode.py
Executable file
1062
egs/librispeech/ASR/zipformer_adapter/decode.py
Executable file
File diff suppressed because it is too large
Load Diff
1115
egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
Executable file
1115
egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_adapter/decoder.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/decoder.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/decoder.py
|
1
egs/librispeech/ASR/zipformer_adapter/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../transducer_stateless/encoder_interface.py
|
621
egs/librispeech/ASR/zipformer_adapter/export-onnx.py
Executable file
621
egs/librispeech/ASR/zipformer_adapter/export-onnx.py
Executable file
@ -0,0 +1,621 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
|
||||||
|
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal False \
|
||||||
|
--chunk-size "16,32,64,-1" \
|
||||||
|
--left-context-frames "64,128,256,-1"
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
||||||
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
|
help="Path to the tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
add_finetune_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Zipformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_embed = encoder_embed
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of Zipformer.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
|
"""
|
||||||
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2)
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2)
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, x_lens),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "x_lens"],
|
||||||
|
output_names=["encoder_out", "encoder_out_lens"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "zipformer2",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "non-streaming zipformer2",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
decoder_model = torch.jit.script(decoder_model)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_embed=model.encoder_embed,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/librispeech/ASR/zipformer_adapter/joiner.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/joiner.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/joiner.py
|
1
egs/librispeech/ASR/zipformer_adapter/model.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/model.py
|
385
egs/librispeech/ASR/zipformer_adapter/onnx_decode.py
Executable file
385
egs/librispeech/ASR/zipformer_adapter/onnx_decode.py
Executable file
@ -0,0 +1,385 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX exported models and uses them to decode the test sets.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--causal False
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
2. Run this file
|
||||||
|
|
||||||
|
./zipformer/onnx_decode.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
|
||||||
|
from onnx_pretrained import greedy_search, OnnxModel
|
||||||
|
|
||||||
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
from k2 import SymbolTable
|
||||||
|
|
||||||
|
conversational_filler = [
|
||||||
|
"UH",
|
||||||
|
"UHH",
|
||||||
|
"UM",
|
||||||
|
"EH",
|
||||||
|
"MM",
|
||||||
|
"HM",
|
||||||
|
"AH",
|
||||||
|
"HUH",
|
||||||
|
"HA",
|
||||||
|
"ER",
|
||||||
|
"OOF",
|
||||||
|
"HEE",
|
||||||
|
"ACH",
|
||||||
|
"EEE",
|
||||||
|
"EW",
|
||||||
|
]
|
||||||
|
unk_tags = ["<UNK>", "<unk>"]
|
||||||
|
gigaspeech_punctuations = [
|
||||||
|
"<COMMA>",
|
||||||
|
"<PERIOD>",
|
||||||
|
"<QUESTIONMARK>",
|
||||||
|
"<EXCLAMATIONPOINT>",
|
||||||
|
]
|
||||||
|
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
|
||||||
|
non_scoring_words = (
|
||||||
|
conversational_filler
|
||||||
|
+ unk_tags
|
||||||
|
+ gigaspeech_punctuations
|
||||||
|
+ gigaspeech_garbage_utterance_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def asr_text_post_processing(text: str) -> str:
|
||||||
|
# 1. convert to uppercase
|
||||||
|
text = text.upper()
|
||||||
|
|
||||||
|
# 2. remove hyphen
|
||||||
|
# "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
|
||||||
|
text = text.replace("-", " ")
|
||||||
|
|
||||||
|
# 3. remove non-scoring words from evaluation
|
||||||
|
remaining_words = []
|
||||||
|
for word in text.split():
|
||||||
|
if word in non_scoring_words:
|
||||||
|
continue
|
||||||
|
remaining_words.append(word)
|
||||||
|
|
||||||
|
return " ".join(remaining_words)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="""Path to tokens.txt.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def post_processing(
|
||||||
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
|
) -> List[Tuple[str, List[str], List[str]]]:
|
||||||
|
new_results = []
|
||||||
|
for key, ref, hyp in results:
|
||||||
|
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||||
|
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
|
||||||
|
new_results.append((key, new_ref, new_hyp))
|
||||||
|
return new_results
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
model: OnnxModel, token_table: SymbolTable, batch: dict
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""Decode one batch and return the result.
|
||||||
|
Currently it only greedy_search is supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
The token table.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return the decoded results for each utterance.
|
||||||
|
"""
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
hyps = greedy_search(
|
||||||
|
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
|
||||||
|
)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
|
hyps = [token_ids_to_words(h).split() for h in hyps]
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: SymbolTable,
|
||||||
|
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
The token table.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A list of tuples. Each tuple contains three elements:
|
||||||
|
- cut_id,
|
||||||
|
- reference transcript,
|
||||||
|
- predicted result.
|
||||||
|
- The total duration (in seconds) of the dataset.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
log_interval = 10
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
|
||||||
|
|
||||||
|
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
|
||||||
|
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results.extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
|
||||||
|
return results, total_duration
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
res_dir: Path,
|
||||||
|
test_set_name: str,
|
||||||
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
|
):
|
||||||
|
recog_path = res_dir / f"recogs-{test_set_name}.txt"
|
||||||
|
results = post_processing(results)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = res_dir / f"errs-{test_set_name}.txt"
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("WER", file=f)
|
||||||
|
print(wer, file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.decoding_method == "greedy_search"
|
||||||
|
), "Only supports greedy_search currently."
|
||||||
|
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
|
||||||
|
|
||||||
|
setup_logger(f"{res_dir}/log-decode")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
token_table = SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = OnnxModel(
|
||||||
|
encoder_model_filename=args.encoder_model_filename,
|
||||||
|
decoder_model_filename=args.decoder_model_filename,
|
||||||
|
joiner_model_filename=args.joiner_model_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
|
||||||
|
gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
|
||||||
|
|
||||||
|
dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
|
||||||
|
test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
|
||||||
|
|
||||||
|
test_sets = ["dev", "test"]
|
||||||
|
test_dl = [dev_dl, test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
|
start_time = time.time()
|
||||||
|
results, total_duration = decode_dataset(
|
||||||
|
dl=test_dl, model=model, token_table=token_table
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
rtf = elapsed_seconds / total_duration
|
||||||
|
|
||||||
|
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
|
logging.info(f"Wave duration: {total_duration:.3f} s")
|
||||||
|
logging.info(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/onnx_pretrained.py
|
1
egs/librispeech/ASR/zipformer_adapter/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/optim.py
|
1
egs/librispeech/ASR/zipformer_adapter/scaling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/scaling.py
|
1
egs/librispeech/ASR/zipformer_adapter/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/scaling_converter.py
|
1
egs/librispeech/ASR/zipformer_adapter/subsampling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adapter/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/subsampling.py
|
1541
egs/librispeech/ASR/zipformer_adapter/train.py
Executable file
1541
egs/librispeech/ASR/zipformer_adapter/train.py
Executable file
File diff suppressed because it is too large
Load Diff
2515
egs/librispeech/ASR/zipformer_adapter/zipformer.py
Normal file
2515
egs/librispeech/ASR/zipformer_adapter/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user