mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 17:44:20 +00:00
add the pruned_transducer_stateless7_streaming recipe for commonvoice
This commit is contained in:
parent
d67a49afe4
commit
f31d31ff1c
@ -26,7 +26,7 @@ repos:
|
||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.10.1
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args: ["--profile=black"]
|
||||
|
166
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
166
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates HLG from
|
||||
|
||||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_n_gram.fst.txt
|
||||
|
||||
The generated HLG is saved in $lang_dir/HLG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
lm:
|
||||
The language stem base name.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
logging.info("Composing H and LG")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
logging.info(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lang_dir, args.lm)
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
141
egs/commonvoice/ASR/local/compile_lg.py
Executable file
141
egs/commonvoice/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates LG from
|
||||
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_3_gram.fst.txt
|
||||
|
||||
The generated LG is saved in $lang_dir/LG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_LG(lang_dir: str) -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
|
||||
Return:
|
||||
An FSA representing LG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/G_3_gram.pt").is_file():
|
||||
logging.info("Loading pre-compiled G_3_gram")
|
||||
d = torch.load(f"{lang_dir}/lm/G_3_gram.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info("Loading G_3_gram.fst.txt")
|
||||
with open(f"{lang_dir}/lm/G_3_gram.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/G_3_gram.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
# assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
# LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
if isinstance(LG.aux_labels, k2.RaggedTensor):
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
else:
|
||||
LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
return LG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "LG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
LG = compile_LG(lang_dir)
|
||||
logging.info(f"Saving LG.pt to {lang_dir}")
|
||||
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -56,8 +56,8 @@ def get_args():
|
||||
def compute_fbank_commonvoice_dev_test(language: str):
|
||||
src_dir = Path(f"data/{language}/manifests")
|
||||
output_dir = Path(f"data/{language}/fbank")
|
||||
num_workers = 42
|
||||
batch_duration = 600
|
||||
num_workers = 16
|
||||
batch_duration = 200
|
||||
|
||||
subsets = ("dev", "test")
|
||||
|
||||
|
@ -43,9 +43,14 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def normalize_text(utt: str) -> str:
|
||||
def normalize_text(utt: str, language: str) -> str:
|
||||
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
utt = re.sub("’", "'", utt)
|
||||
if language == "en":
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
if language == "fr":
|
||||
utt = utt.upper()
|
||||
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt)
|
||||
|
||||
|
||||
def preprocess_commonvoice(
|
||||
@ -94,7 +99,7 @@ def preprocess_commonvoice(
|
||||
for sup in m["supervisions"]:
|
||||
text = str(sup.text)
|
||||
orig_text = text
|
||||
sup.text = normalize_text(sup.text)
|
||||
sup.text = normalize_text(sup.text, language)
|
||||
text = str(sup.text)
|
||||
if len(orig_text) != len(text):
|
||||
logging.info(
|
||||
|
@ -36,8 +36,8 @@ num_splits=1000
|
||||
# - speech
|
||||
|
||||
dl_dir=$PWD/download
|
||||
release=cv-corpus-13.0-2023-03-09
|
||||
lang=en
|
||||
release=cv-corpus-12.0-2022-12-07
|
||||
lang=fr
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
@ -138,7 +138,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
|
||||
./local/compute_fbank_commonvoice_splits.py \
|
||||
--num-workers $nj \
|
||||
--batch-duration 600 \
|
||||
--batch-duration 200 \
|
||||
--start 0 \
|
||||
--num-splits $num_splits \
|
||||
--language $lang
|
||||
@ -181,7 +181,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
|
||||
sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/words.txt ]; then
|
||||
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_dir/words.txt
|
||||
@ -208,14 +208,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
}' > $lang_dir/words || exit 1;
|
||||
mv $lang_dir/words $lang_dir/words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
@ -242,3 +242,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
mkdir -p $lang_dir/lm
|
||||
#3-gram used in building HLG, 4-gram used for LM rescoring
|
||||
for ngram in 3 4; do
|
||||
if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order ${ngram} \
|
||||
-text $lang_dir/transcript_words.txt \
|
||||
-lm $lang_dir/lm/${ngram}gram.arpa
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=${ngram} \
|
||||
$lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
|
||||
fi
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Compile HLG"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
./local/compile_hlg.py --lang-dir $lang_dir
|
||||
|
||||
# Note If ./local/compile_hlg.py throws OOM,
|
||||
# please switch to the following command
|
||||
#
|
||||
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
# Compile LG for RNN-T fast_beam_search decoding
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 12: Compile LG"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
@ -0,0 +1,9 @@
|
||||
This recipe implements Streaming Zipformer-Transducer model.
|
||||
|
||||
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
|
||||
|
||||
[./emformer.py](./emformer.py) and [./train.py](./train.py)
|
||||
are basically the same as
|
||||
[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
|
||||
The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
|
||||
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py
|
@ -0,0 +1,422 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class CommonVoiceAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default="fr",
|
||||
help="""Language of Common Voice""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--cv-manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fr/fbank"),
|
||||
help="Path to directory with CommonVoice train/dev/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
|
||||
)
|
810
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
810
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
@ -0,0 +1,810 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from commonvoice_fr import CommonVoiceAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += 30
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, 30),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
|
||||
# errs_info = (
|
||||
# params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
# )
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CommonVoiceAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
commonvoice = CommonVoiceAsrDataModule(args)
|
||||
|
||||
test_cuts = commonvoice.test_cuts()
|
||||
|
||||
test_dl = commonvoice.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = "test-cv"
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_sets,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,151 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
class DecodeStream(object):
|
||||
def __init__(
|
||||
self,
|
||||
params: AttributeDict,
|
||||
cut_id: str,
|
||||
initial_states: List[torch.Tensor],
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
initial_states:
|
||||
Initial decode states of the model, e.g. the return value of
|
||||
`get_init_state` in conformer.py
|
||||
decoding_graph:
|
||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||
Used only when decoding_method is fast_beam_search.
|
||||
device:
|
||||
The device to run this stream.
|
||||
"""
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
assert decoding_graph is not None
|
||||
assert device == decoding_graph.device
|
||||
|
||||
self.params = params
|
||||
self.cut_id = cut_id
|
||||
self.LOG_EPS = math.log(1e-10)
|
||||
|
||||
self.states = initial_states
|
||||
|
||||
# It contains a 2-D tensors representing the feature frames.
|
||||
self.features: torch.Tensor = None
|
||||
|
||||
self.num_frames: int = 0
|
||||
# how many frames have been processed. (before subsampling).
|
||||
# we only modify this value in `func:get_feature_frames`.
|
||||
self.num_processed_frames: int = 0
|
||||
|
||||
self._done: bool = False
|
||||
|
||||
# The transcript of current utterance.
|
||||
self.ground_truth: str = ""
|
||||
|
||||
# The decoding result (partial or final) of current utterance.
|
||||
self.hyp: List = []
|
||||
|
||||
# how many frames have been processed, after subsampling (i.e. a
|
||||
# cumulative sum of the second return value of
|
||||
# encoder.streaming_forward
|
||||
self.done_frames: int = 0
|
||||
|
||||
# It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2
|
||||
# 1) feature embedding: out_lens=(x_lens-7)//2
|
||||
# 2) output subsampling: out_lens=(out_lens+1)//2
|
||||
self.pad_length = 7
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
self.hyp = [params.blank_id] * params.context_size
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
self.hyps = HypothesisList()
|
||||
self.hyps.add(
|
||||
Hypothesis(
|
||||
ys=[params.blank_id] * params.context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
# The rnnt_decoding_stream for fast_beam_search.
|
||||
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
|
||||
decoding_graph
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Return True if all the features are processed."""
|
||||
return self._done
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.cut_id
|
||||
|
||||
def set_features(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
tail_pad_len: int = 0,
|
||||
) -> None:
|
||||
"""Set features tensor of current utterance."""
|
||||
assert features.dim() == 2, features.dim()
|
||||
self.features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, self.pad_length + tail_pad_len),
|
||||
mode="constant",
|
||||
value=self.LOG_EPS,
|
||||
)
|
||||
self.num_frames = self.features.size(0)
|
||||
|
||||
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Consume chunk_size frames of features"""
|
||||
chunk_length = chunk_size + self.pad_length
|
||||
|
||||
ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
|
||||
|
||||
ret_features = self.features[
|
||||
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
|
||||
]
|
||||
|
||||
self.num_processed_frames += chunk_size
|
||||
if self.num_processed_frames >= self.num_frames:
|
||||
self._done = True
|
||||
|
||||
return ret_features, ret_length
|
||||
|
||||
def decoding_result(self) -> List[int]:
|
||||
"""Obtain current decoding result."""
|
||||
if self.params.decoding_method == "greedy_search":
|
||||
return self.hyp[self.params.context_size :] # noqa
|
||||
elif self.params.decoding_method == "modified_beam_search":
|
||||
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[self.params.context_size :] # noqa
|
||||
else:
|
||||
assert self.params.decoding_method == "fast_beam_search"
|
||||
return self.hyp
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/decoder.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
|
367
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
Executable file
367
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
Executable file
@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Please see
|
||||
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||
for more details about how to use this file.
|
||||
|
||||
We use
|
||||
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||
to demonstrate the usage of this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_char_bpe/L.pt"
|
||||
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
|
||||
git lfs pull --include "data/lang_char_bpe/Linv.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export to ncnn
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||
--lang-dir $repo/data/lang_char_bpe \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--num-encoder-layers "2,4,3,2,4" \
|
||||
--feedforward-dims "1024,1024,1536,1536,1024" \
|
||||
--nhead "8,8,8,8,8" \
|
||||
--encoder-dims "384,384,384,384,384" \
|
||||
--attention-dims "192,192,192,192,192" \
|
||||
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||
--cnn-module-kernels "31,31,31,31,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
|
||||
cd $repo/exp
|
||||
|
||||
pnnx encoder_jit_trace-pnnx.pt
|
||||
pnnx decoder_jit_trace-pnnx.pt
|
||||
pnnx joiner_jit_trace-pnnx.pt
|
||||
|
||||
You can find converted models at
|
||||
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
|
||||
|
||||
See ./streaming-ncnn-decode.py
|
||||
and
|
||||
https://github.com/k2-fsa/sherpa-ncnn
|
||||
for usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||
|
||||
decode_chunk_len = encoder_model.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
T = decode_chunk_len + pad_length # 32 + 7 = 39
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
states = encoder_model.get_init_state()
|
||||
|
||||
traced_model = torch.jit.trace(encoder_model, (x, states))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
logging.info("Using torch.jit.trace()")
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
main()
|
369
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
Executable file
369
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
Executable file
@ -0,0 +1,369 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Please see
|
||||
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||
for more details about how to use this file.
|
||||
|
||||
We use
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
to demonstrate the usage of this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export to ncnn
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
\
|
||||
--decode-chunk-len 32 \
|
||||
--num-encoder-layers "2,4,3,2,4" \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--nhead "8,8,8,8,8" \
|
||||
--encoder-dims "384,384,384,384,384" \
|
||||
--attention-dims "192,192,192,192,192" \
|
||||
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||
--cnn-module-kernels "31,31,31,31,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
|
||||
cd $repo/exp
|
||||
|
||||
pnnx encoder_jit_trace-pnnx.pt
|
||||
pnnx decoder_jit_trace-pnnx.pt
|
||||
pnnx joiner_jit_trace-pnnx.pt
|
||||
|
||||
You can find converted models at
|
||||
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
|
||||
|
||||
See ./streaming-ncnn-decode.py
|
||||
and
|
||||
https://github.com/k2-fsa/sherpa-ncnn
|
||||
for usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||
|
||||
decode_chunk_len = encoder_model.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
T = decode_chunk_len + pad_length # 32 + 7 = 39
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
states = encoder_model.get_init_state()
|
||||
|
||||
traced_model = torch.jit.trace(encoder_model, (x, states))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
logging.info("Using torch.jit.trace()")
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
main()
|
647
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
Executable file
647
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
Executable file
@ -0,0 +1,647 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
1. Download the pre-trained model
|
||||
cd egs/librispeech/ASR
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
2. Export the model to ONNX
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
It will generate the following 3 files in $repo/exp
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
See ./onnx_pretrained.py for how to use the exported models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
|
||||
"""Please see the help information of Zipformer.streaming_forward"""
|
||||
N = x.size(0)
|
||||
T = x.size(1)
|
||||
x_lens = torch.tensor([T] * N, device=x.device)
|
||||
|
||||
output, _, new_states = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
output = self.encoder_proj(output)
|
||||
# Now output is of shape (N, T, joiner_dim)
|
||||
|
||||
return output, new_states
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""
|
||||
Onnx model inputs:
|
||||
- 0: src
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
Onnx model outputs:
|
||||
- 0: output, its shape is (N, T, joiner_dim)
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
Args:
|
||||
encoder_model:
|
||||
The model to be exported
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
)
|
||||
|
||||
decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
T = decode_chunk_len + pad_length
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"pad_length: {pad_length}")
|
||||
logging.info(f"T: {T}")
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
|
||||
init_state = encoder_model.encoder.get_init_state()
|
||||
|
||||
num_encoders = encoder_model.encoder.num_encoders
|
||||
logging.info(f"num_encoders: {num_encoders}")
|
||||
logging.info(f"len(init_state): {len(init_state)}")
|
||||
|
||||
inputs = {}
|
||||
input_names = ["x"]
|
||||
|
||||
outputs = {}
|
||||
output_names = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, name, N):
|
||||
for i, s in enumerate(tensors):
|
||||
logging.info(f"{name}_{i}.shape: {s.shape}")
|
||||
inputs[f"{name}_{i}"] = {N: "N"}
|
||||
outputs[f"new_{name}_{i}"] = {N: "N"}
|
||||
input_names.append(f"{name}_{i}")
|
||||
output_names.append(f"new_{name}_{i}")
|
||||
|
||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims))
|
||||
attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims))
|
||||
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels))
|
||||
ds = encoder_model.encoder.zipformer_downsampling_factors
|
||||
left_context_len = encoder_model.encoder.left_context_len
|
||||
left_context_len = [left_context_len // k for k in ds]
|
||||
left_context_len = ",".join(map(str, left_context_len))
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||
"T": str(T), # 39
|
||||
"num_encoder_layers": num_encoder_layers,
|
||||
"encoder_dims": encoder_dims,
|
||||
"attention_dims": attention_dims,
|
||||
"cnn_module_kernels": cnn_module_kernels,
|
||||
"left_context_len": left_context_len,
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
# (num_encoder_layers, 1)
|
||||
cached_len = init_state[num_encoders * 0 : num_encoders * 1]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim)
|
||||
cached_avg = init_state[num_encoders * 1 : num_encoders * 2]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim)
|
||||
cached_key = init_state[num_encoders * 2 : num_encoders * 3]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val = init_state[num_encoders * 3 : num_encoders * 4]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val2 = init_state[num_encoders * 4 : num_encoders * 5]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
build_inputs_outputs(cached_len, "cached_len", 1)
|
||||
build_inputs_outputs(cached_avg, "cached_avg", 1)
|
||||
build_inputs_outputs(cached_key, "cached_key", 2)
|
||||
build_inputs_outputs(cached_val, "cached_val", 2)
|
||||
build_inputs_outputs(cached_val2, "cached_val2", 2)
|
||||
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
|
||||
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
|
||||
|
||||
logging.info(inputs)
|
||||
logging.info(outputs)
|
||||
logging.info(input_names)
|
||||
logging.info(output_names)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, init_state),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
"x": {0: "N"},
|
||||
"encoder_out": {0: "N"},
|
||||
**inputs,
|
||||
**outputs,
|
||||
},
|
||||
)
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
The exported model has one input:
|
||||
- y: a torch.int64 tensor of shape (N, context_size)
|
||||
and has one output:
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
Note: The argument need_pad is fixed to False.
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
and produces one output:
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
if params.use_averaged_model:
|
||||
suffix += "-with-averaged-model"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
878
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py
Executable file
878
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py
Executable file
@ -0,0 +1,878 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
||||
|
||||
(3) Export to ONNX format with pretrained.pt
|
||||
|
||||
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
ln -s pretrained.pt epoch-999.pt
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model False \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--fp16 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(4) Export to ONNX format for triton server
|
||||
|
||||
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
ln -s pretrained.pt epoch-999.pt
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model False \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--fp16 \
|
||||
--onnx-triton 1 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa/tree/master/triton
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import onnxruntime
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import stack_states
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-triton",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --onnx would export model into the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="whether to export fp16 onnx model, default false",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
|
||||
for a, b in zip(xlist, blist):
|
||||
try:
|
||||
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
|
||||
except AssertionError as error:
|
||||
if tolerate_small_mismatch:
|
||||
print("small mismatch detected", error)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
batch_size = 17
|
||||
seq_len = 101
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
# 11 is not supported. Please feel free to request support or
|
||||
# submit a pull request on PyTorch GitHub.
|
||||
#
|
||||
# I cannot find which statement causes the above error.
|
||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||
# works well for the current reworked model
|
||||
initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
|
||||
states = stack_states(initial_states)
|
||||
|
||||
left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
|
||||
encoder_attention_dim = encoder_model.encoders[0].attention_dim
|
||||
|
||||
len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15
|
||||
avg_cache = torch.cat(
|
||||
states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
|
||||
).transpose(
|
||||
0, 1
|
||||
) # [B,15,384]
|
||||
cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
|
||||
0, 1
|
||||
) # [B,2*15,384,cnn_kernel-1]
|
||||
pad_tensors = [
|
||||
torch.nn.functional.pad(
|
||||
tensor,
|
||||
(
|
||||
0,
|
||||
encoder_attention_dim - tensor.shape[-1],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
left_context_len - tensor.shape[1],
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
for tensor in states[
|
||||
2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
|
||||
]
|
||||
]
|
||||
attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
|
||||
|
||||
encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model_wrapper,
|
||||
(x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"x",
|
||||
"x_lens",
|
||||
"len_cache",
|
||||
"avg_cache",
|
||||
"attn_cache",
|
||||
"cnn_cache",
|
||||
],
|
||||
output_names=[
|
||||
"encoder_out",
|
||||
"encoder_out_lens",
|
||||
"new_len_cache",
|
||||
"new_avg_cache",
|
||||
"new_attn_cache",
|
||||
"new_cnn_cache",
|
||||
],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
"len_cache": {0: "N"},
|
||||
"avg_cache": {0: "N"},
|
||||
"attn_cache": {0: "N"},
|
||||
"cnn_cache": {0: "N"},
|
||||
"new_len_cache": {0: "N"},
|
||||
"new_avg_cache": {0: "N"},
|
||||
"new_attn_cache": {0: "N"},
|
||||
"new_cnn_cache": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
# Test onnx encoder with torch native encoder
|
||||
encoder_model.eval()
|
||||
(
|
||||
encoder_out_torch,
|
||||
encoder_out_lens_torch,
|
||||
new_states_torch,
|
||||
) = encoder_model.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
ort_session = onnxruntime.InferenceSession(
|
||||
str(encoder_filename), providers=["CPUExecutionProvider"]
|
||||
)
|
||||
ort_inputs = {
|
||||
"x": x.numpy(),
|
||||
"x_lens": x_lens.numpy(),
|
||||
"len_cache": len_cache.numpy(),
|
||||
"avg_cache": avg_cache.numpy(),
|
||||
"attn_cache": attn_cache.numpy(),
|
||||
"cnn_cache": cnn_cache.numpy(),
|
||||
}
|
||||
ort_outs = ort_session.run(None, ort_inputs)
|
||||
|
||||
assert test_acc(
|
||||
[encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
|
||||
)
|
||||
logging.info(f"{encoder_filename} acc test succeeded.")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx_triton(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
|
||||
decoder_model = TritonOnnxDecoder(decoder_model)
|
||||
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y,),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
|
||||
The exported encoder_proj model has one input:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
The exported decoder_proj model has one input:
|
||||
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
project_input = False
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx_triton(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
and has one output:
|
||||
- joiner_out: a tensor of shape (N, vocab_size)
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
joiner_model = TritonOnnxJoiner(joiner_model)
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(encoder_out, decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out", "decoder_out"],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.onnx:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
opset_version = 13
|
||||
logging.info("Exporting to onnx format")
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
if not params.onnx_triton:
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
else:
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx_triton(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx_triton(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
if params.fp16:
|
||||
try:
|
||||
import onnxmltools
|
||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||
except ImportError:
|
||||
print("Please install onnxmltools!")
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
||||
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
||||
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
|
||||
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||
|
||||
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
|
||||
export_onnx_fp16(encoder_filename, encoder_fp16_filename)
|
||||
|
||||
decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
|
||||
export_onnx_fp16(decoder_filename, decoder_fp16_filename)
|
||||
|
||||
joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
|
||||
export_onnx_fp16(joiner_filename, joiner_fp16_filename)
|
||||
|
||||
if not params.onnx_triton:
|
||||
encoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_encoder_proj.onnx"
|
||||
)
|
||||
encoder_proj_fp16_filename = (
|
||||
params.exp_dir / "joiner_encoder_proj_fp16.onnx"
|
||||
)
|
||||
export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)
|
||||
|
||||
decoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_decoder_proj.onnx"
|
||||
)
|
||||
decoder_proj_fp16_filename = (
|
||||
params.exp_dir / "joiner_decoder_proj_fp16.onnx"
|
||||
)
|
||||
export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
|
||||
|
||||
elif params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1342
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
Executable file
1342
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
|
||||
|
||||
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
|
||||
|
||||
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
|
||||
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model."
|
||||
"If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
print("Script started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
print("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
print(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
print(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
278
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
Executable file
278
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
Executable file
@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-len",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The chunk size for decoding (in frames before subsampling)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
model.encoder.decode_chunk_size = args.decode_chunk_len // 2
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
313
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
Executable file
313
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
Executable file
@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--decode-chunk-len 32
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
params: AttributeDict,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
decode_chunk_len = params.decode_chunk_len # before subsampling
|
||||
pad_length = 7
|
||||
s = f"decode_chunk_len: {decode_chunk_len}"
|
||||
logging.info(s)
|
||||
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
|
||||
encoder_model.decode_chunk_size,
|
||||
decode_chunk_len,
|
||||
)
|
||||
|
||||
T = decode_chunk_len + pad_length
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder_model.get_init_state(device=x.device)
|
||||
|
||||
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models exported by `torch.jit.trace()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--decode-chunk-len 32
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--decode-chunk-len 32 \
|
||||
/path/to/foo.wav \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-len",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The chunk size for decoding (in frames before subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
assert encoder_out.ndim == 2
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
|
||||
# decoder_input.shape (1,, 1 context_size)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
else:
|
||||
assert decoder_out.ndim == 2
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for i in range(T):
|
||||
cur_encoder_out = encoder_out[i : i + 1]
|
||||
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor(args.sample_rate)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
logging.info("Decoding started")
|
||||
chunk_length = args.decode_chunk_len
|
||||
assert encoder.decode_chunk_size == chunk_length // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
chunk_length,
|
||||
)
|
||||
|
||||
# we subsample features with ((x_len - 7) // 2 + 1) // 2
|
||||
pad_length = 7
|
||||
T = chunk_length + pad_length
|
||||
|
||||
logging.info(f"chunk_length: {chunk_length}")
|
||||
|
||||
states = encoder.get_init_state(device)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
chunk = int(0.25 * args.sample_rate) # 0.2 second
|
||||
num_processed_frames = 0
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
logging.info(f"{start}/{wave_samples.numel()}")
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=args.sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= T:
|
||||
frames = []
|
||||
for i in range(T):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
x_lens = torch.tensor([T], dtype=torch.int32)
|
||||
encoder_out, out_lens, states = encoder(
|
||||
x=frames,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
num_processed_frames += chunk_length
|
||||
|
||||
hyp, decoder_out = greedy_search(
|
||||
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
|
||||
context_size = 2
|
||||
logging.info(args.sound_file)
|
||||
logging.info(sp.decode(hyp[context_size:]))
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/joiner.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/model.py
|
260
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
Executable file
260
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
Executable file
@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script checks that exported ONNX models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model via torch.jit.trace()
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp
|
||||
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
3. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
4. Run this file
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_check.py \
|
||||
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
from zipformer import stack_states
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
torch_encoder_model: torch.jit.ScriptModule,
|
||||
torch_encoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
T = onnx_model.segment
|
||||
C = 80
|
||||
x_lens = torch.tensor([T] * N)
|
||||
torch_states = [torch_encoder_model.get_init_state() for _ in range(N)]
|
||||
torch_states = stack_states(torch_states)
|
||||
|
||||
onnx_model.init_encoder_states(N)
|
||||
|
||||
for i in range(5):
|
||||
logging.info(f"test_encoder: iter {i}")
|
||||
x = torch.rand(N, T, C)
|
||||
torch_encoder_out, _, torch_states = torch_encoder_model(
|
||||
x, x_lens, torch_states
|
||||
)
|
||||
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
|
||||
|
||||
onnx_encoder_out = onnx_model.run_encoder(x)
|
||||
|
||||
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
|
||||
(torch_encoder_out - onnx_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
torch_decoder_model: torch.jit.ScriptModule,
|
||||
torch_decoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
context_size = onnx_model.context_size
|
||||
vocab_size = onnx_model.vocab_size
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_decoder: iter {i}, N={N}")
|
||||
x = torch.randint(
|
||||
low=1,
|
||||
high=vocab_size,
|
||||
size=(N, context_size),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
|
||||
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
|
||||
torch_decoder_out = torch_decoder_out.squeeze(1)
|
||||
|
||||
onnx_decoder_out = onnx_model.run_decoder(x)
|
||||
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
|
||||
(torch_decoder_out - onnx_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
torch_joiner_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_joiner: iter {i}, N={N}")
|
||||
encoder_out = torch.rand(N, encoder_dim)
|
||||
decoder_out = torch.rand(N, decoder_dim)
|
||||
|
||||
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
|
||||
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
|
||||
|
||||
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
|
||||
onnx_joiner_out = onnx_model.run_joiner(
|
||||
projected_encoder_out, projected_decoder_out
|
||||
)
|
||||
|
||||
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
|
||||
(torch_joiner_out - onnx_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
|
||||
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
|
||||
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
|
||||
|
||||
onnx_model = OnnxModel(
|
||||
encoder_model_filename=args.onnx_encoder_filename,
|
||||
decoder_model_filename=args.onnx_decoder_filename,
|
||||
joiner_model_filename=args.onnx_joiner_filename,
|
||||
)
|
||||
|
||||
logging.info("Test encoder")
|
||||
# When exporting the model to onnx, we have already put the encoder_proj
|
||||
# inside the encoder.
|
||||
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test decoder")
|
||||
# When exporting the model to onnx, we have already put the decoder_proj
|
||||
# inside the decoder.
|
||||
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test joiner")
|
||||
test_joiner(torch_joiner_model, onnx_model)
|
||||
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/38342
|
||||
# and https://github.com/pytorch/pytorch/issues/33354
|
||||
#
|
||||
# If we don't do this, the delay increases whenever there is
|
||||
# a new request that changes the actual batch size.
|
||||
# If you use `py-spy dump --pid <server-pid> --native`, you will
|
||||
# see a lot of time is spent in re-compiling the torch script model.
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20230207)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1,231 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class OnnxStreamingEncoder(torch.nn.Module):
|
||||
"""This class warps the streaming Zipformer to reduce the number of
|
||||
state tensors for onnx.
|
||||
https://github.com/k2-fsa/icefall/pull/831
|
||||
"""
|
||||
|
||||
def __init__(self, encoder):
|
||||
"""
|
||||
Args:
|
||||
encoder: An instance of Zipformer Class
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = encoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
len_cache: torch.tensor,
|
||||
avg_cache: torch.tensor,
|
||||
attn_cache: torch.tensor,
|
||||
cnn_cache: torch.tensor,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
len_cache:
|
||||
The cached numbers of past frames.
|
||||
avg_cache:
|
||||
The cached average tensors.
|
||||
attn_cache:
|
||||
The cached key tensors of the first attention modules.
|
||||
The cached value tensors of the first attention modules.
|
||||
The cached value tensors of the second attention modules.
|
||||
cnn_cache:
|
||||
The cached left contexts of the first convolution modules.
|
||||
The cached left contexts of the second convolution modules.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
|
||||
"""
|
||||
num_encoder_layers = []
|
||||
encoder_attention_dims = []
|
||||
states = []
|
||||
for i, encoder in enumerate(self.model.encoders):
|
||||
num_encoder_layers.append(encoder.num_layers)
|
||||
encoder_attention_dims.append(encoder.attention_dim)
|
||||
|
||||
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(len_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(avg_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
|
||||
left_context_len = attn_cache.shape[1]
|
||||
offset = 0
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[offset : offset + num_layer, : left_context_len // ds]
|
||||
)
|
||||
offset += num_layer
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
encoder_attention_dim = encoder_attention_dims[i]
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[
|
||||
offset : offset + num_layer,
|
||||
: left_context_len // ds,
|
||||
:,
|
||||
: encoder_attention_dim // 2,
|
||||
]
|
||||
)
|
||||
offset += num_layer
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[
|
||||
offset : offset + num_layer,
|
||||
: left_context_len // ds,
|
||||
:,
|
||||
: encoder_attention_dim // 2,
|
||||
]
|
||||
)
|
||||
offset += num_layer
|
||||
|
||||
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(cnn_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(cnn_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
|
||||
0, 1
|
||||
) # [B,15]
|
||||
new_avg_cache = torch.cat(
|
||||
states[self.model.num_encoders : 2 * self.model.num_encoders]
|
||||
).transpose(
|
||||
0, 1
|
||||
) # [B,15,384]
|
||||
new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
|
||||
0, 1
|
||||
) # [B,2*15,384,cnn_kernel-1]
|
||||
assert len(set(encoder_attention_dims)) == 1
|
||||
pad_tensors = [
|
||||
torch.nn.functional.pad(
|
||||
tensor,
|
||||
(
|
||||
0,
|
||||
encoder_attention_dims[0] - tensor.shape[-1],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
left_context_len - tensor.shape[1],
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
for tensor in states[
|
||||
2 * self.model.num_encoders : 5 * self.model.num_encoders
|
||||
]
|
||||
]
|
||||
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
|
||||
|
||||
return (
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_len_cache,
|
||||
new_avg_cache,
|
||||
new_attn_cache,
|
||||
new_cnn_cache,
|
||||
)
|
||||
|
||||
|
||||
class TritonOnnxDecoder(torch.nn.Module):
|
||||
"""This class warps the Decoder in decoder.py
|
||||
to remove the scalar input "need_pad".
|
||||
Triton currently doesn't support scalar input.
|
||||
https://github.com/triton-inference-server/server/issues/2333
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: torch.nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
decoder: A instance of Decoder
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = decoder
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
# False to not pad the input. Should be False during inference.
|
||||
need_pad = False
|
||||
return self.model(y, need_pad)
|
||||
|
||||
|
||||
class TritonOnnxJoiner(torch.nn.Module):
|
||||
"""This class warps the Joiner in joiner.py
|
||||
to remove the scalar input "project_input".
|
||||
Triton currently doesn't support scalar input.
|
||||
https://github.com/triton-inference-server/server/issues/2333
|
||||
"project_input" is set to True.
|
||||
Triton solutions only need export joiner to a single joiner.onnx.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
joiner: torch.nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = joiner
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
# Apply input projections encoder_proj and decoder_proj.
|
||||
project_input = True
|
||||
return self.model(encoder_out, decoder_out, project_input)
|
512
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
Executable file
512
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
Executable file
@ -0,0 +1,512 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files in $repo/exp
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "zipformer", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
|
||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||
encoder_dims = encoder_meta["encoder_dims"]
|
||||
attention_dims = encoder_meta["attention_dims"]
|
||||
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
|
||||
left_context_len = encoder_meta["left_context_len"]
|
||||
|
||||
def to_int_list(s):
|
||||
return list(map(int, s.split(",")))
|
||||
|
||||
num_encoder_layers = to_int_list(num_encoder_layers)
|
||||
encoder_dims = to_int_list(encoder_dims)
|
||||
attention_dims = to_int_list(attention_dims)
|
||||
cnn_module_kernels = to_int_list(cnn_module_kernels)
|
||||
left_context_len = to_int_list(left_context_len)
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
logging.info(f"encoder_dims: {encoder_dims}")
|
||||
logging.info(f"attention_dims: {attention_dims}")
|
||||
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
|
||||
logging.info(f"left_context_len: {left_context_len}")
|
||||
|
||||
num_encoders = len(num_encoder_layers)
|
||||
|
||||
cached_len = []
|
||||
cached_avg = []
|
||||
cached_key = []
|
||||
cached_val = []
|
||||
cached_val2 = []
|
||||
cached_conv1 = []
|
||||
cached_conv2 = []
|
||||
|
||||
N = batch_size
|
||||
|
||||
for i in range(num_encoders):
|
||||
cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64))
|
||||
cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i]))
|
||||
cached_key.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], left_context_len[i], N, attention_dims[i]
|
||||
)
|
||||
)
|
||||
cached_val.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_val2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_conv1.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
cached_conv2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
|
||||
self.cached_len = cached_len
|
||||
self.cached_avg = cached_avg
|
||||
self.cached_key = cached_key
|
||||
self.cached_val = cached_val
|
||||
self.cached_val2 = cached_val2
|
||||
self.cached_conv1 = cached_conv1
|
||||
self.cached_conv2 = cached_conv2
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
|
||||
self.segment = T
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
encoder_input = {"x": x.numpy()}
|
||||
encoder_output = ["encoder_out"]
|
||||
|
||||
def build_states_input(states: List[torch.Tensor], name: str):
|
||||
for i, s in enumerate(states):
|
||||
if isinstance(s, torch.Tensor):
|
||||
encoder_input[f"{name}_{i}"] = s.numpy()
|
||||
else:
|
||||
encoder_input[f"{name}_{i}"] = s
|
||||
|
||||
encoder_output.append(f"new_{name}_{i}")
|
||||
|
||||
build_states_input(self.cached_len, "cached_len")
|
||||
build_states_input(self.cached_avg, "cached_avg")
|
||||
build_states_input(self.cached_key, "cached_key")
|
||||
build_states_input(self.cached_val, "cached_val")
|
||||
build_states_input(self.cached_val2, "cached_val2")
|
||||
build_states_input(self.cached_conv1, "cached_conv1")
|
||||
build_states_input(self.cached_conv2, "cached_conv2")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
num_encoders = self.num_encoders
|
||||
|
||||
self.cached_len = states[num_encoders * 0 : num_encoders * 1]
|
||||
self.cached_avg = states[num_encoders * 1 : num_encoders * 2]
|
||||
self.cached_key = states[num_encoders * 2 : num_encoders * 3]
|
||||
self.cached_val = states[num_encoders * 3 : num_encoders * 4]
|
||||
self.cached_val2 = states[num_encoders * 4 : num_encoders * 5]
|
||||
self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6]
|
||||
self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/optim.py
|
355
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Executable file
355
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Executable file
@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7_streaming/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling_converter.py
|
@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
|
||||
--tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \
|
||||
--encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \
|
||||
--encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \
|
||||
--decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \
|
||||
--decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \
|
||||
--joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \
|
||||
--joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \
|
||||
./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav
|
||||
|
||||
You can find pretrained models at
|
||||
- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
|
||||
- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import ncnn
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-param-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-param-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-param-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-bin-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename",
|
||||
type=str,
|
||||
help="Path to foo.wav",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def to_int_tuple(s: str):
|
||||
return tuple(map(int, s.split(",")))
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, args):
|
||||
self.init_encoder(args)
|
||||
self.init_decoder(args)
|
||||
self.init_joiner(args)
|
||||
|
||||
# Please change the parameters according to your model
|
||||
self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
|
||||
self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model
|
||||
self.attention_dims = to_int_tuple("192,192,192,192,192")
|
||||
self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
|
||||
self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
|
||||
|
||||
self.decode_chunk_size = 32 // 2
|
||||
num_left_chunks = 4
|
||||
self.left_context_length = self.decode_chunk_size * num_left_chunks # 64
|
||||
|
||||
self.chunk_length = self.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
self.T = self.chunk_length + pad_length
|
||||
|
||||
def get_init_states(self) -> List[torch.Tensor]:
|
||||
cached_len_list = []
|
||||
cached_avg_list = []
|
||||
cached_key_list = []
|
||||
cached_val_list = []
|
||||
cached_val2_list = []
|
||||
cached_conv1_list = []
|
||||
cached_conv2_list = []
|
||||
|
||||
for i in range(len(self.num_encoder_layers)):
|
||||
num_layers = self.num_encoder_layers[i]
|
||||
ds = self.zipformer_downsampling_factors[i]
|
||||
attention_dim = self.attention_dims[i]
|
||||
left_context_length = self.left_context_length // ds
|
||||
encoder_dim = self.encoder_dims[i]
|
||||
cnn_module_kernel = self.cnn_module_kernels[i]
|
||||
|
||||
cached_len_list.append(torch.zeros(num_layers))
|
||||
cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
|
||||
cached_key_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim)
|
||||
)
|
||||
cached_val_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||
)
|
||||
cached_val2_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||
)
|
||||
cached_conv1_list.append(
|
||||
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||
)
|
||||
cached_conv2_list.append(
|
||||
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||
)
|
||||
|
||||
states = (
|
||||
cached_len_list
|
||||
+ cached_avg_list
|
||||
+ cached_key_list
|
||||
+ cached_val_list
|
||||
+ cached_val2_list
|
||||
+ cached_conv1_list
|
||||
+ cached_conv2_list
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
def init_encoder(self, args):
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_net.opt.num_threads = 4
|
||||
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
encoder_net.load_param(encoder_param)
|
||||
encoder_net.load_model(encoder_model)
|
||||
|
||||
self.encoder_net = encoder_net
|
||||
|
||||
def init_decoder(self, args):
|
||||
decoder_param = args.decoder_param_filename
|
||||
decoder_model = args.decoder_bin_filename
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
decoder_net.opt.num_threads = 4
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
|
||||
self.decoder_net = decoder_net
|
||||
|
||||
def init_joiner(self, args):
|
||||
joiner_param = args.joiner_param_filename
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.opt.num_threads = 4
|
||||
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
self.joiner_net = joiner_net
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
states: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, C)
|
||||
states:
|
||||
A list of tensors. len(states) == self.num_layers * 4
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, a tensor of shape (T, encoder_dim).
|
||||
- next_states, a list of tensors containing the next states
|
||||
"""
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
|
||||
for i in range(len(states)):
|
||||
name = f"in{i+1}"
|
||||
ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone())
|
||||
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
|
||||
out_states: List[torch.Tensor] = []
|
||||
for i in range(len(states)):
|
||||
name = f"out{i+1}"
|
||||
ret, ncnn_out_state = ex.extract(name)
|
||||
assert ret == 0, ret
|
||||
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
|
||||
|
||||
if i < len(self.num_encoder_layers):
|
||||
# for cached_len, we need to discard the last dim
|
||||
ncnn_out_state = ncnn_out_state.squeeze(1)
|
||||
|
||||
out_states.append(ncnn_out_state)
|
||||
|
||||
return encoder_out, out_states
|
||||
|
||||
def run_decoder(self, decoder_input):
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return decoder_out
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return joiner_out
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Model,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
else:
|
||||
assert decoder_out.ndim == 1
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t]
|
||||
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = Model(args)
|
||||
|
||||
sound_file = args.sound_filename
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
states = model.get_init_states()
|
||||
logging.info(f"number of states: {len(states)}")
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.T
|
||||
offset = model.chunk_length
|
||||
|
||||
chunk = int(1 * sample_rate) # 0.2 second
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
encoder_out, states = model.run_encoder(frames, states)
|
||||
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
context_size = 2
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(sound_file)
|
||||
logging.info(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
|
612
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
612
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
@ -0,0 +1,612 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 2000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from commonvoice_fr import CommonVoiceAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import stack_states, unstack_states
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A List of DecodeStream, each belonging to a utterance.
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||
# factor in encoders is 8.
|
||||
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||
tail_length = 23
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = stack_states(states)
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
|
||||
states = unstack_states(new_states)
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = states[i]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse Cutset containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
idx = 0
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
initial_states = model.encoder.get_init_state(device=device)
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
if audio.max() > 1 or audio.min() < -1:
|
||||
audio = audio / max(abs(audio.max()), abs(audio.min()))
|
||||
print(audio)
|
||||
print(audio.max())
|
||||
print(audio.min())
|
||||
print(cut)
|
||||
idx += 1
|
||||
print(idx)
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CommonVoiceAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
commonvoice = CommonVoiceAsrDataModule(args)
|
||||
test_cuts = commonvoice.test_cuts()
|
||||
test_sets = "test-cv"
|
||||
|
||||
results_dict = decode_dataset(
|
||||
cuts=test_cuts,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_sets,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
150
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
150
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless7_streaming/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
# Test jit script
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
print("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
||||
|
||||
def test_model_jit_trace():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
# Test encoder
|
||||
def _test_encoder():
|
||||
encoder = model.encoder
|
||||
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
T = params.decode_chunk_len + 7
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder.get_init_state(device=x.device)
|
||||
encoder.__class__.forward = encoder.__class__.streaming_forward
|
||||
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
|
||||
|
||||
states1 = encoder.get_init_state(device=x.device)
|
||||
states2 = traced_encoder.get_init_state(device=x.device)
|
||||
for i in range(5):
|
||||
x = torch.randn(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
|
||||
y2, _, states2 = traced_encoder(x, x_lens, states2)
|
||||
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
|
||||
|
||||
# Test decoder
|
||||
def _test_decoder():
|
||||
decoder = model.decoder
|
||||
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
|
||||
d1 = decoder(y, need_pad)
|
||||
d2 = traced_decoder(y, need_pad)
|
||||
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
|
||||
|
||||
# Test joiner
|
||||
def _test_joiner():
|
||||
joiner = model.joiner
|
||||
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
|
||||
j1 = joiner(encoder_out, decoder_out)
|
||||
j2 = traced_joiner(encoder_out, decoder_out)
|
||||
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
|
||||
|
||||
_test_encoder()
|
||||
_test_decoder()
|
||||
_test_joiner()
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
test_model_jit_trace()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1256
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
1256
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1257
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
1257
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -426,7 +426,6 @@ class NgramCounts:
|
||||
if __name__ == "__main__":
|
||||
|
||||
ngram_counts = NgramCounts(args.ngram_order)
|
||||
|
||||
if args.text is None:
|
||||
ngram_counts.add_raw_counts_from_standard_input()
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user