mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 17:44:20 +00:00
fix the symlinks
This commit is contained in:
parent f31d31ff1c · commit 3b84a6aff9
egs/commonvoice/ASR/local/compile_hlg.py (deleted file)
@@ -1,166 +0,0 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates HLG from

    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt

    - L, the lexicon, built from lang_dir/L_disambig.pt

        Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_n_gram.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        help="""Stem name for LM used in HLG compiling.
        """,
    )
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
      lm:
        The language stem base name.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir, args.lm)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
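For quick reference, the composition pipeline implemented by the deleted script above (and still provided by the shared librispeech copy) boils down to the following condensed sketch. It reuses only the k2 calls that appear in the script and assumes H, L, and G are already-loaded k2.Fsa objects, with disambiguation-symbol removal omitted for brevity.

import k2

def compose_hlg(H: "k2.Fsa", L: "k2.Fsa", G: "k2.Fsa") -> "k2.Fsa":
    # L o G, then clean up: connect, determinize, remove epsilons.
    LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
    LG = k2.connect(k2.determinize(k2.connect(LG)))
    LG = k2.connect(k2.remove_epsilon(LG))
    # H o (L o G); the inner label name "tokens" is what icefall expects.
    HLG = k2.compose(H, k2.arc_sort(LG), inner_labels="tokens")
    return k2.arc_sort(k2.connect(HLG))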
egs/commonvoice/ASR/local/compile_hlg.py (symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg.py
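The fix itself is mechanical: the duplicated commonvoice copy is deleted and replaced by a relative symlink into the librispeech recipe. A minimal sketch of that replacement (paths taken from the hunk above) could look like this:

from pathlib import Path

def relink(link: Path, target: str) -> None:
    # Drop the stale copy (or broken link) and point it at the shared script.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target)

relink(
    Path("egs/commonvoice/ASR/local/compile_hlg.py"),
    "../../../librispeech/ASR/local/compile_hlg.py",
)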
egs/commonvoice/ASR/local/compile_lg.py (deleted file)
@@ -1,141 +0,0 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

        Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load(f"{lang_dir}/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open(f"{lang_dir}/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        torch.save(G.as_dict(), f"{lang_dir}/lm/G_3_gram.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)

    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)
    logging.info("Removing disambiguation symbols on LG")
    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None
    # assert isinstance(LG.aux_labels, k2.RaggedTensor)
    # LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
    if isinstance(LG.aux_labels, k2.RaggedTensor):
        LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
    else:
        LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)

    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)
    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
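The disambiguation-symbol removal used in both compile scripts can be illustrated on a toy tensor. This stand-in assumes the token "#0" happens to get id 500; in the real scripts the id comes from lexicon.token_table["#0"].

import torch

labels = torch.tensor([3, 7, 500, 502, 1])     # toy arc labels
first_token_disambig_id = 500                  # id of "#0" in this toy table
labels[labels >= first_token_disambig_id] = 0  # disambig arcs become epsilon
print(labels)                                  # tensor([3, 7, 0, 0, 1])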
egs/commonvoice/ASR/local/compile_lg.py (symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py
@@ -1,151 +0,0 @@
# Copyright    2022  Xiaomi Corp.        (authors: Wei Kang,
#                                                  Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple

import k2
import torch
from beam_search import Hypothesis, HypothesisList

from icefall.utils import AttributeDict


class DecodeStream(object):
    def __init__(
        self,
        params: AttributeDict,
        cut_id: str,
        initial_states: List[torch.Tensor],
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Args:
          initial_states:
            Initial decode states of the model, e.g. the return value of
            `get_init_state` in conformer.py
          decoding_graph:
            Decoding graph used for decoding, may be a TrivialGraph or a HLG.
            Used only when decoding_method is fast_beam_search.
          device:
            The device to run this stream.
        """
        if params.decoding_method == "fast_beam_search":
            assert decoding_graph is not None
            assert device == decoding_graph.device

        self.params = params
        self.cut_id = cut_id
        self.LOG_EPS = math.log(1e-10)

        self.states = initial_states

        # It contains a 2-D tensors representing the feature frames.
        self.features: torch.Tensor = None

        self.num_frames: int = 0
        # how many frames have been processed. (before subsampling).
        # we only modify this value in `func:get_feature_frames`.
        self.num_processed_frames: int = 0

        self._done: bool = False

        # The transcript of current utterance.
        self.ground_truth: str = ""

        # The decoding result (partial or final) of current utterance.
        self.hyp: List = []

        # how many frames have been processed, after subsampling (i.e. a
        # cumulative sum of the second return value of
        # encoder.streaming_forward
        self.done_frames: int = 0

        # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2
        # 1) feature embedding: out_lens=(x_lens-7)//2
        # 2) output subsampling: out_lens=(out_lens+1)//2
        self.pad_length = 7

        if params.decoding_method == "greedy_search":
            self.hyp = [params.blank_id] * params.context_size
        elif params.decoding_method == "modified_beam_search":
            self.hyps = HypothesisList()
            self.hyps.add(
                Hypothesis(
                    ys=[params.blank_id] * params.context_size,
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                )
            )
        elif params.decoding_method == "fast_beam_search":
            # The rnnt_decoding_stream for fast_beam_search.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
                decoding_graph
            )
        else:
            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    @property
    def done(self) -> bool:
        """Return True if all the features are processed."""
        return self._done

    @property
    def id(self) -> str:
        return self.cut_id

    def set_features(
        self,
        features: torch.Tensor,
        tail_pad_len: int = 0,
    ) -> None:
        """Set features tensor of current utterance."""
        assert features.dim() == 2, features.dim()
        self.features = torch.nn.functional.pad(
            features,
            (0, 0, 0, self.pad_length + tail_pad_len),
            mode="constant",
            value=self.LOG_EPS,
        )
        self.num_frames = self.features.size(0)

    def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
        """Consume chunk_size frames of features"""
        chunk_length = chunk_size + self.pad_length

        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)

        ret_features = self.features[
            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
        ]

        self.num_processed_frames += chunk_size
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return ret_features, ret_length

    def decoding_result(self) -> List[int]:
        """Obtain current decoding result."""
        if self.params.decoding_method == "greedy_search":
            return self.hyp[self.params.context_size :]  # noqa
        elif self.params.decoding_method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.params.context_size :]  # noqa
        else:
            assert self.params.decoding_method == "fast_beam_search"
            return self.hyp
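The subsampling formula noted in the comment above can be checked with a couple of lines: for the default 32-frame decoding chunk plus 7 padding frames, the encoder sees 39 input frames and emits 8 subsampled frames.

def subsampled_len(x_lens: int) -> int:
    # Two subsampling steps from the comment:
    # feature embedding (x_lens - 7) // 2, then output subsampling (+ 1) // 2.
    return ((x_lens - 7) // 2 + 1) // 2

assert subsampled_len(32 + 7) == 8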
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
@@ -1,367 +0,0 @@
#!/usr/bin/env python3

"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.

We use
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
to demonstrate the usage of this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_char_bpe/L.pt"
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export to ncnn

./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
  --lang-dir $repo/data/lang_char_bpe \
  --exp-dir $repo/exp \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --decode-chunk-len 32 \
  --num-encoder-layers "2,4,3,2,4" \
  --feedforward-dims "1024,1024,1536,1536,1024" \
  --nhead "8,8,8,8,8" \
  --encoder-dims "384,384,384,384,384" \
  --attention-dims "192,192,192,192,192" \
  --encoder-unmasked-dims "256,256,256,256,256" \
  --zipformer-downsampling-factors "1,2,4,8,2" \
  --cnn-module-kernels "31,31,31,31,31" \
  --decoder-dim 512 \
  --joiner-dim 512

cd $repo/exp

pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt

You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13

See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""

import argparse
import logging
from pathlib import Path

import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7_streaming/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="The lang dir",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    add_model_arguments(parser)

    return parser


def export_encoder_model_jit_trace(
    encoder_model: torch.nn.Module,
    encoder_filename: str,
) -> None:
    """Export the given encoder model with torch.jit.trace()

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    encoder_model.__class__.forward = encoder_model.__class__.streaming_forward

    decode_chunk_len = encoder_model.decode_chunk_size * 2
    pad_length = 7
    T = decode_chunk_len + pad_length  # 32 + 7 = 39

    logging.info(f"decode_chunk_len: {decode_chunk_len}")
    logging.info(f"T: {T}")

    x = torch.zeros(1, T, 80, dtype=torch.float32)
    states = encoder_model.get_init_state()

    traced_model = torch.jit.trace(encoder_model, (x, states))
    traced_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_jit_trace(
    decoder_model: torch.nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.trace()

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = torch.tensor([False])

    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
    traced_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_trace(
    joiner_model: torch.nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.

    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
    traced_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")

    setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")

    logging.info(f"device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)

    encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
    decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
    joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    logging.info("Using torch.jit.trace()")

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
    export_decoder_model_jit_trace(model.decoder, decoder_filename)

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
    export_joiner_model_jit_trace(model.joiner, joiner_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    main()
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
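Before running pnnx, a traced module can be sanity-checked by loading it back with torch.jit.load. The file name and context size below follow the export code above; everything else is a hypothetical smoke test, not part of the recipe.

import torch

decoder = torch.jit.load("decoder_jit_trace-pnnx.pt")
y = torch.zeros(1, 2, dtype=torch.int64)  # (N, context_size=2)
need_pad = torch.tensor([False])          # fixed to False, as in the export
print(decoder(y, need_pad).shape)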
@@ -1,369 +0,0 @@
#!/usr/bin/env python3

"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.

We use
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
to demonstrate the usage of this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export to ncnn

./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --exp-dir $repo/exp \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  \
  --decode-chunk-len 32 \
  --num-encoder-layers "2,4,3,2,4" \
  --feedforward-dims "1024,1024,2048,2048,1024" \
  --nhead "8,8,8,8,8" \
  --encoder-dims "384,384,384,384,384" \
  --attention-dims "192,192,192,192,192" \
  --encoder-unmasked-dims "256,256,256,256,256" \
  --zipformer-downsampling-factors "1,2,4,8,2" \
  --cnn-module-kernels "31,31,31,31,31" \
  --decoder-dim 512 \
  --joiner-dim 512

cd $repo/exp

pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt

You can find converted models at
https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13

See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""

import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import setup_logger, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7_streaming/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    add_model_arguments(parser)

    return parser


def export_encoder_model_jit_trace(
    encoder_model: torch.nn.Module,
    encoder_filename: str,
) -> None:
    """Export the given encoder model with torch.jit.trace()

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    encoder_model.__class__.forward = encoder_model.__class__.streaming_forward

    decode_chunk_len = encoder_model.decode_chunk_size * 2
    pad_length = 7
    T = decode_chunk_len + pad_length  # 32 + 7 = 39

    logging.info(f"decode_chunk_len: {decode_chunk_len}")
    logging.info(f"T: {T}")

    x = torch.zeros(1, T, 80, dtype=torch.float32)
    states = encoder_model.get_init_state()

    traced_model = torch.jit.trace(encoder_model, (x, states))
    traced_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_jit_trace(
    decoder_model: torch.nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.trace()

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = torch.tensor([False])

    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
    traced_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_trace(
    joiner_model: torch.nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.

    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
    traced_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")

    setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True)

    encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
    decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
    joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    logging.info("Using torch.jit.trace()")

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
    export_decoder_model_jit_trace(model.decoder, decoder_filename)

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
    export_joiner_model_jit_trace(model.joiner, joiner_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    main()
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@ -1,647 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
1. Download the pre-trained model
|
||||
cd egs/librispeech/ASR
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
2. Export the model to ONNX
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
It will generate the following 3 files in $repo/exp
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
See ./onnx_pretrained.py for how to use the exported models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
|
||||
"""Please see the help information of Zipformer.streaming_forward"""
|
||||
N = x.size(0)
|
||||
T = x.size(1)
|
||||
x_lens = torch.tensor([T] * N, device=x.device)
|
||||
|
||||
output, _, new_states = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
output = self.encoder_proj(output)
|
||||
# Now output is of shape (N, T, joiner_dim)
|
||||
|
||||
return output, new_states
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""
|
||||
Onnx model inputs:
|
||||
- 0: src
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
Onnx model outputs:
|
||||
- 0: output, its shape is (N, T, joiner_dim)
|
||||
- many state tensors (the exact number depending on the actual model)
|
||||
Args:
|
||||
encoder_model:
|
||||
The model to be exported
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
)
|
||||
|
||||
decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
T = decode_chunk_len + pad_length
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"pad_length: {pad_length}")
|
||||
logging.info(f"T: {T}")
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
|
||||
init_state = encoder_model.encoder.get_init_state()
|
||||
|
||||
num_encoders = encoder_model.encoder.num_encoders
|
||||
logging.info(f"num_encoders: {num_encoders}")
|
||||
logging.info(f"len(init_state): {len(init_state)}")
|
||||
|
||||
inputs = {}
|
||||
input_names = ["x"]
|
||||
|
||||
outputs = {}
|
||||
output_names = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, name, N):
|
||||
for i, s in enumerate(tensors):
|
||||
logging.info(f"{name}_{i}.shape: {s.shape}")
|
||||
inputs[f"{name}_{i}"] = {N: "N"}
|
||||
outputs[f"new_{name}_{i}"] = {N: "N"}
|
||||
input_names.append(f"{name}_{i}")
|
||||
output_names.append(f"new_{name}_{i}")
|
||||
|
||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims))
|
||||
attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims))
|
||||
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels))
|
||||
ds = encoder_model.encoder.zipformer_downsampling_factors
|
||||
left_context_len = encoder_model.encoder.left_context_len
|
||||
left_context_len = [left_context_len // k for k in ds]
|
||||
left_context_len = ",".join(map(str, left_context_len))
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||
"T": str(T), # 39
|
||||
"num_encoder_layers": num_encoder_layers,
|
||||
"encoder_dims": encoder_dims,
|
||||
"attention_dims": attention_dims,
|
||||
"cnn_module_kernels": cnn_module_kernels,
|
||||
"left_context_len": left_context_len,
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
# (num_encoder_layers, 1)
|
||||
cached_len = init_state[num_encoders * 0 : num_encoders * 1]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim)
|
||||
cached_avg = init_state[num_encoders * 1 : num_encoders * 2]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim)
|
||||
cached_key = init_state[num_encoders * 2 : num_encoders * 3]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val = init_state[num_encoders * 3 : num_encoders * 4]
|
||||
|
||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||
cached_val2 = init_state[num_encoders * 4 : num_encoders * 5]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6]
|
||||
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
build_inputs_outputs(cached_len, "cached_len", 1)
|
||||
build_inputs_outputs(cached_avg, "cached_avg", 1)
|
||||
build_inputs_outputs(cached_key, "cached_key", 2)
|
||||
build_inputs_outputs(cached_val, "cached_val", 2)
|
||||
build_inputs_outputs(cached_val2, "cached_val2", 2)
|
||||
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
|
||||
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
|
||||
|
||||
logging.info(inputs)
|
||||
logging.info(outputs)
|
||||
logging.info(input_names)
|
||||
logging.info(output_names)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, init_state),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
"x": {0: "N"},
|
||||
"encoder_out": {0: "N"},
|
||||
**inputs,
|
||||
**outputs,
|
||||
},
|
||||
)
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
The exported model has one input:
|
||||
- y: a torch.int64 tensor of shape (N, context_size)
|
||||
and has one output:
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
Note: The argument need_pad is fixed to False.
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
and produces one output:
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True)
    encoder = OnnxEncoder(
        encoder=model.encoder,
        encoder_proj=model.joiner.encoder_proj,
    )

    decoder = OnnxDecoder(
        decoder=model.decoder,
        decoder_proj=model.joiner.decoder_proj,
    )

    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)

    encoder_num_param = sum([p.numel() for p in encoder.parameters()])
    decoder_num_param = sum([p.numel() for p in decoder.parameters()])
    joiner_num_param = sum([p.numel() for p in joiner.parameters()])
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"
    if params.use_averaged_model:
        suffix += "-with-averaged-model"

    opset_version = 13

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
    export_decoder_model_onnx(
        decoder,
        decoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
    export_joiner_model_onnx(
        joiner,
        joiner_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
        model_output=encoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=decoder_filename,
        model_output=decoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=joiner_filename,
        model_output=joiner_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
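

# A small illustrative helper (not called above): compare the file sizes of an
# fp32 model and its int8 counterpart. Dynamic int8 quantization of MatMul
# weights typically shrinks a model to very roughly a quarter of its fp32
# size; accuracy should still be re-checked with decode.py before deployment.
def report_quantization_ratio(fp32_path, int8_path) -> None:
    import os

    fp32_size = os.path.getsize(fp32_path)
    int8_size = os.path.getsize(int8_path)
    logging.info(f"{fp32_path}: {fp32_size} bytes")
    logging.info(f"{int8_path}: {int8_size} bytes ({int8_size / fp32_size:.2%})")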


if __name__ == "__main__":
    main()
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
@ -1,878 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:

(1) Export to torchscript model using torch.jit.script()

./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 9 \
  --jit 1

It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.

Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.

Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.

(2) Export `model.state_dict()`

./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10

It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.

To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
you can do:

cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt

cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless7_streaming/decode.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --epoch 9999 \
  --avg 1 \
  --max-duration 600 \
  --decoding-method greedy_search \
  --bpe-model data/lang_bpe_500/bpe.model

Check ./pretrained.py for its usage.

Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at

https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29

with the following commands:

sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp

(3) Export to ONNX format with pretrained.pt

cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --use-averaged-model False \
  --epoch 999 \
  --avg 1 \
  --fp16 \
  --onnx 1

It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.

  - encoder.onnx
  - decoder.onnx
  - joiner.onnx
  - joiner_encoder_proj.onnx
  - joiner_decoder_proj.onnx

Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.

(4) Export to ONNX format for triton server

cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --use-averaged-model False \
  --epoch 999 \
  --avg 1 \
  --fp16 \
  --onnx-triton 1 \
  --onnx 1

It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.

  - encoder.onnx
  - decoder.onnx
  - joiner.onnx

Check
https://github.com/k2-fsa/sherpa/tree/master/triton
for how to use the exported models outside of icefall.

"""


import argparse
import logging
from pathlib import Path

import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7_streaming/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named cpu_jit.pt.

        Check ./jit_pretrained.py for how to use it.
        """,
    )

    parser.add_argument(
        "--onnx",
        type=str2bool,
        default=False,
        help="""If True, --jit is ignored and it exports the model
        to onnx format. It will generate the following files:

            - encoder.onnx
            - decoder.onnx
            - joiner.onnx
            - joiner_encoder_proj.onnx
            - joiner_decoder_proj.onnx

        Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--onnx-triton",
        type=str2bool,
        default=False,
        help="""If True, --onnx would export the model into the following files:

            - encoder.onnx
            - decoder.onnx
            - joiner.onnx

        These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton.
        """,
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to also export an fp16 onnx model; defaults to false.",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    add_model_arguments(parser)

    return parser


def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
    """Check that corresponding tensors in the two lists are element-wise
    close. Small mismatches are printed but tolerated when
    tolerate_small_mismatch is True; otherwise the first mismatch
    returns False."""
    for a, b in zip(xlist, blist):
        try:
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        except AssertionError as error:
            if tolerate_small_mismatch:
                print("small mismatch detected", error)
            else:
                return False
    return True
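

# Illustrative use of test_acc() (the encoder export below calls it the same
# way): compare torch outputs against onnxruntime outputs, tolerating small
# numerical differences.
#
#   ok = test_acc([torch_out.numpy()], [ort_out])
#   assert ok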


def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.

    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - encoder_out, a tensor of shape (N, T, C)
        - encoder_out_lens, a tensor of shape (N,)

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    batch_size = 17
    seq_len = 101
    torch.manual_seed(0)
    x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
    x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)

    # encoder_model = torch.jit.script(encoder_model)
    # It throws the following error for the above statement
    #
    # RuntimeError: Exporting the operator __is_ to ONNX opset version
    # 11 is not supported. Please feel free to request support or
    # submit a pull request on PyTorch GitHub.
    #
    # I cannot find which statement causes the above error.
    # torch.onnx.export() will use torch.jit.trace() internally, which
    # works well for the current reworked model
    initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
    states = stack_states(initial_states)

    left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
    encoder_attention_dim = encoder_model.encoders[0].attention_dim

    len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1)  # B,15
    avg_cache = torch.cat(
        states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
    ).transpose(
        0, 1
    )  # [B,15,384]
    cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
        0, 1
    )  # [B,2*15,384,cnn_kernel-1]
    pad_tensors = [
        torch.nn.functional.pad(
            tensor,
            (
                0,
                encoder_attention_dim - tensor.shape[-1],
                0,
                0,
                0,
                left_context_len - tensor.shape[1],
                0,
                0,
            ),
        )
        for tensor in states[
            2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
        ]
    ]
    attn_cache = torch.cat(pad_tensors).transpose(0, 2)  # [B,64,15*3,192]

    encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)

    torch.onnx.export(
        encoder_model_wrapper,
        (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_lens",
            "len_cache",
            "avg_cache",
            "attn_cache",
            "cnn_cache",
        ],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "new_len_cache",
            "new_avg_cache",
            "new_attn_cache",
            "new_cnn_cache",
        ],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
            "len_cache": {0: "N"},
            "avg_cache": {0: "N"},
            "attn_cache": {0: "N"},
            "cnn_cache": {0: "N"},
            "new_len_cache": {0: "N"},
            "new_avg_cache": {0: "N"},
            "new_attn_cache": {0: "N"},
            "new_cnn_cache": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_filename}")

    # Test the ONNX encoder against the torch native encoder
    encoder_model.eval()
    (
        encoder_out_torch,
        encoder_out_lens_torch,
        new_states_torch,
    ) = encoder_model.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=states,
    )
    ort_session = onnxruntime.InferenceSession(
        str(encoder_filename), providers=["CPUExecutionProvider"]
    )
    ort_inputs = {
        "x": x.numpy(),
        "x_lens": x_lens.numpy(),
        "len_cache": len_cache.numpy(),
        "avg_cache": avg_cache.numpy(),
        "attn_cache": attn_cache.numpy(),
        "cnn_cache": cnn_cache.numpy(),
    }
    ort_outs = ort_session.run(None, ort_inputs)

    assert test_acc(
        [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
    )
    logging.info(f"{encoder_filename} acc test succeeded.")


def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = False  # Always False, so we can use torch.jit.trace() here
    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
    # in this case
    torch.onnx.export(
        decoder_model,
        (y, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_decoder_model_onnx_triton(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format for the Triton server.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)

    decoder_model = TritonOnnxDecoder(decoder_model)

    torch.onnx.export(
        decoder_model,
        (y,),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.

    The exported joiner model has two inputs:

        - projected_encoder_out: a tensor of shape (N, joiner_dim)
        - projected_decoder_out: a tensor of shape (N, joiner_dim)

    and produces one output:

        - logit: a tensor of shape (N, vocab_size)

    The exported encoder_proj model has one input:

        - encoder_out: a tensor of shape (N, encoder_out_dim)

    and produces one output:

        - projected_encoder_out: a tensor of shape (N, joiner_dim)

    The exported decoder_proj model has one input:

        - decoder_out: a tensor of shape (N, decoder_out_dim)

    and produces one output:

        - projected_decoder_out: a tensor of shape (N, joiner_dim)
    """
    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")

    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    joiner_dim = joiner_model.decoder_proj.weight.shape[0]

    projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
    projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)

    project_input = False
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (projected_encoder_out, projected_decoder_out, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "encoder_out",
            "decoder_out",
            "project_input",
        ],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")

    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    torch.onnx.export(
        joiner_model.encoder_proj,
        encoder_out,
        encoder_proj_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out"],
        output_names=["projected_encoder_out"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "projected_encoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_proj_filename}")

    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
    torch.onnx.export(
        joiner_model.decoder_proj,
        decoder_out,
        decoder_proj_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["decoder_out"],
        output_names=["projected_decoder_out"],
        dynamic_axes={
            "decoder_out": {0: "N"},
            "projected_decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_proj_filename}")
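

# A minimal usage sketch (illustrative; not called by this script): chain the
# three files written by export_joiner_model_onnx() to compute joiner logits.
# `encoder_out`/`decoder_out` here stand in for unprojected model outputs.
def run_joiner_with_projections(
    joiner_filename: str, encoder_out: torch.Tensor, decoder_out: torch.Tensor
):
    providers = ["CPUExecutionProvider"]
    enc_proj = onnxruntime.InferenceSession(
        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
        providers=providers,
    )
    dec_proj = onnxruntime.InferenceSession(
        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
        providers=providers,
    )
    joiner = onnxruntime.InferenceSession(str(joiner_filename), providers=providers)

    (proj_enc,) = enc_proj.run(
        ["projected_encoder_out"], {"encoder_out": encoder_out.numpy()}
    )
    (proj_dec,) = dec_proj.run(
        ["projected_decoder_out"], {"decoder_out": decoder_out.numpy()}
    )
    joiner_dim = proj_enc.shape[-1]
    # The joiner above was traced with inputs of shape (N, 1, 1, joiner_dim)
    # and only dim 0 is dynamic, so reshape the 2-D projection outputs.
    (logit,) = joiner.run(
        ["logit"],
        {
            "encoder_out": proj_enc.reshape(-1, 1, 1, joiner_dim),
            "decoder_out": proj_dec.reshape(-1, 1, 1, joiner_dim),
        },
    )
    return logit  # shape: (N, ..., vocab_size)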


def export_joiner_model_onnx_triton(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format for the Triton server.

    The exported model has two inputs:

        - encoder_out: a tensor of shape (N, encoder_out_dim)
        - decoder_out: a tensor of shape (N, decoder_out_dim)

    and has one output:

        - logit: a tensor of shape (N, vocab_size)

    Note: The argument project_input is fixed to True. Users should not
    project the encoder_out/decoder_out by themselves. The exported joiner
    will do that for them.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    joiner_model = TritonOnnxJoiner(joiner_model)
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (encoder_out, decoder_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out"],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    if params.onnx:
        convert_scaled_to_non_scaled(model, inplace=True)
        opset_version = 13
        logging.info("Exporting to onnx format")
        encoder_filename = params.exp_dir / "encoder.onnx"
        export_encoder_model_onnx(
            model.encoder,
            encoder_filename,
            opset_version=opset_version,
        )
        if not params.onnx_triton:
            decoder_filename = params.exp_dir / "decoder.onnx"
            export_decoder_model_onnx(
                model.decoder,
                decoder_filename,
                opset_version=opset_version,
            )

            joiner_filename = params.exp_dir / "joiner.onnx"
            export_joiner_model_onnx(
                model.joiner,
                joiner_filename,
                opset_version=opset_version,
            )
        else:
            decoder_filename = params.exp_dir / "decoder.onnx"
            export_decoder_model_onnx_triton(
                model.decoder,
                decoder_filename,
                opset_version=opset_version,
            )

            joiner_filename = params.exp_dir / "joiner.onnx"
            export_joiner_model_onnx_triton(
                model.joiner,
                joiner_filename,
                opset_version=opset_version,
            )

        if params.fp16:
            try:
                import onnxmltools
                from onnxmltools.utils.float16_converter import convert_float_to_float16
            except ImportError:
                print("Please install onnxmltools!")
                import sys

                sys.exit(1)

            def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
                onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
                onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
                onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)

            encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
            export_onnx_fp16(encoder_filename, encoder_fp16_filename)

            decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
            export_onnx_fp16(decoder_filename, decoder_fp16_filename)

            joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
            export_onnx_fp16(joiner_filename, joiner_fp16_filename)

            if not params.onnx_triton:
                encoder_proj_filename = str(joiner_filename).replace(
                    ".onnx", "_encoder_proj.onnx"
                )
                encoder_proj_fp16_filename = (
                    params.exp_dir / "joiner_encoder_proj_fp16.onnx"
                )
                export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)

                decoder_proj_filename = str(joiner_filename).replace(
                    ".onnx", "_decoder_proj.onnx"
                )
                decoder_proj_fp16_filename = (
                    params.exp_dir / "joiner_decoder_proj_fp16.onnx"
                )
                export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
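
            # Illustrative follow-up check (an assumption, not performed by
            # this script): after conversion the fp16 model should report
            # float16 inputs, e.g.
            #
            #   sess = onnxruntime.InferenceSession(
            #       str(encoder_fp16_filename), providers=["CPUExecutionProvider"]
            #   )
            #   assert sess.get_inputs()[0].type == "tensor(float16)"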
    elif params.jit:
        convert_scaled_to_non_scaled(model, inplace=True)
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
@ -1,278 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`,
and uses them to decode waves.
You can use the following command to get the exported models:

./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --jit 1

Usage of this script:

./pruned_transducer_stateless7_streaming/jit_pretrained.py \
  --nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  /path/to/foo.wav \
  /path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model cpu_jit.pt",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--decode-chunk-len",
        type=int,
        default=32,
        help="The chunk size for decoding (in frames before subsampling)",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans


def greedy_search(
    model: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
    Returns:
      Return the decoded results for each utterance.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    context_size = model.decoder.context_size
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = model.decoder(
        decoder_input,
        need_pad=torch.tensor([False]),
    ).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=torch.tensor([False]),
            )
            decoder_out = decoder_out.squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)
    model.encoder.decode_chunk_size = args.decode_chunk_len // 2

    model.eval()

    model.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
    )
    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = sp.decode(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
@ -1,313 +0,0 @@
#!/usr/bin/env python3

"""
Usage:
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 10 \
  --use-averaged-model=True \
  --decode-chunk-len 32
"""

import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import AttributeDict, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    add_model_arguments(parser)

    return parser


def export_encoder_model_jit_trace(
    encoder_model: torch.nn.Module,
    encoder_filename: str,
    params: AttributeDict,
) -> None:
    """Export the given encoder model with torch.jit.trace()

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    decode_chunk_len = params.decode_chunk_len  # before subsampling
    pad_length = 7
    s = f"decode_chunk_len: {decode_chunk_len}"
    logging.info(s)
    assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
        encoder_model.decode_chunk_size,
        decode_chunk_len,
    )

    T = decode_chunk_len + pad_length

    x = torch.zeros(1, T, 80, dtype=torch.float32)
    x_lens = torch.full((1,), T, dtype=torch.int32)
    states = encoder_model.get_init_state(device=x.device)

    encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
    traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
    traced_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")
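

# A minimal usage sketch (illustrative): load the traced streaming encoder and
# run one chunk. T must equal decode_chunk_len + pad_length used at export
# time, and `states` must come from the original model's get_init_state()
# (tracing only preserves forward(), not that helper).
def run_traced_encoder(encoder_filename: str, T: int, states):
    encoder = torch.jit.load(encoder_filename)
    x = torch.zeros(1, T, 80, dtype=torch.float32)
    x_lens = torch.full((1,), T, dtype=torch.int32)
    encoder_out, encoder_out_lens, new_states = encoder(x, x_lens, states)
    return encoder_out, new_states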


def export_decoder_model_jit_trace(
    decoder_model: torch.nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.trace()

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = torch.tensor([False])

    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
    traced_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_trace(
    joiner_model: torch.nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()

    Note: The argument project_input is fixed to True. Users should not
    project the encoder_out/decoder_out by themselves. The exported joiner
    will do that for them.

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
    traced_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")
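

# A minimal usage sketch (illustrative): load the traced joiner and run one
# step on random unprojected encoder/decoder outputs. The two dims must match
# the values used at export time.
def run_traced_joiner(
    joiner_filename: str, encoder_out_dim: int, decoder_out_dim: int
) -> torch.Tensor:
    joiner = torch.jit.load(joiner_filename)
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
    logit = joiner(encoder_out, decoder_out)
    return logit  # shape: (1, vocab_size)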


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True)
    logging.info("Using torch.jit.trace()")

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename, params)

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
    export_decoder_model_jit_trace(model.decoder, decoder_filename)

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
    export_joiner_model_jit_trace(model.joiner, joiner_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
@ -1,295 +0,0 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models exported by `torch.jit.trace()`
and uses them to decode waves.
You can use the following command to get the exported models:

./pruned_transducer_stateless7_streaming/jit_trace_export.py \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 10 \
  --use-averaged-model=True \
  --decode-chunk-len 32

Usage of this script:

./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
  --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
  --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
  --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --decode-chunk-len 32 \
  /path/to/foo.wav
"""

import argparse
import logging
import math
from typing import List, Optional

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-len",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The chunk size for decoding (in frames before subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
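# Stateful greedy search over one chunk of encoder output; the caller feeds
# the returned (hyp, decoder_out) back in when decoding the next chunk.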
assert encoder_out.ndim == 2
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
|
||||
# decoder_input.shape: (1, context_size)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
else:
|
||||
assert decoder_out.ndim == 2
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for i in range(T):
|
||||
cur_encoder_out = encoder_out[i : i + 1]
|
||||
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor(args.sample_rate)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
logging.info("Decoding started")
|
||||
chunk_length = args.decode_chunk_len
|
||||
assert encoder.decode_chunk_size == chunk_length // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
chunk_length,
|
||||
)
|
||||
|
||||
# we subsample features with ((x_len - 7) // 2 + 1) // 2
|
||||
pad_length = 7
|
||||
T = chunk_length + pad_length
|
||||
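# For example, with the default --decode-chunk-len=32 the encoder consumes
# T = 32 + 7 = 39 feature frames per chunk, which yields
# ((39 - 7) // 2 + 1) // 2 = 8 encoder output frames.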
|
||||
logging.info(f"chunk_length: {chunk_length}")
|
||||
|
||||
states = encoder.get_init_state(device)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
chunk = int(0.25 * args.sample_rate)  # 0.25 seconds
|
||||
num_processed_frames = 0
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
logging.info(f"{start}/{wave_samples.numel()}")
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=args.sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= T:
|
||||
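# Run the encoder once a full window of T feature frames is available.
# num_processed_frames advances by chunk_length (not T), so consecutive
# windows overlap by pad_length frames.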
frames = []
|
||||
for i in range(T):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
x_lens = torch.tensor([T], dtype=torch.int32)
|
||||
encoder_out, out_lens, states = encoder(
|
||||
x=frames,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
num_processed_frames += chunk_length
|
||||
|
||||
hyp, decoder_out = greedy_search(
|
||||
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
|
||||
context_size = 2
|
||||
logging.info(args.sound_file)
|
||||
logging.info(sp.decode(hyp[context_size:]))
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
|
@ -1,260 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script checks that exported ONNX models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model via torch.jit.trace()
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp
|
||||
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
3. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
4. Run this file
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_check.py \
|
||||
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
from zipformer import stack_states
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the ONNX joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
torch_encoder_model: torch.jit.ScriptModule,
|
||||
torch_encoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
T = onnx_model.segment
|
||||
C = 80
|
||||
x_lens = torch.tensor([T] * N)
|
||||
torch_states = [torch_encoder_model.get_init_state() for _ in range(N)]
|
||||
torch_states = stack_states(torch_states)
|
||||
|
||||
onnx_model.init_encoder_states(N)
|
||||
|
||||
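# Feed the same random chunks to both models. The ONNX encoder already
# contains the joiner's encoder_proj, so apply it to the torchscript output
# before comparing.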
for i in range(5):
|
||||
logging.info(f"test_encoder: iter {i}")
|
||||
x = torch.rand(N, T, C)
|
||||
torch_encoder_out, _, torch_states = torch_encoder_model(
|
||||
x, x_lens, torch_states
|
||||
)
|
||||
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
|
||||
|
||||
onnx_encoder_out = onnx_model.run_encoder(x)
|
||||
|
||||
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
|
||||
(torch_encoder_out - onnx_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
torch_decoder_model: torch.jit.ScriptModule,
|
||||
torch_decoder_proj_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
context_size = onnx_model.context_size
|
||||
vocab_size = onnx_model.vocab_size
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_decoder: iter {i}, N={N}")
|
||||
x = torch.randint(
|
||||
low=1,
|
||||
high=vocab_size,
|
||||
size=(N, context_size),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
|
||||
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
|
||||
torch_decoder_out = torch_decoder_out.squeeze(1)
|
||||
|
||||
onnx_decoder_out = onnx_model.run_decoder(x)
|
||||
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
|
||||
(torch_decoder_out - onnx_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
torch_joiner_model: torch.jit.ScriptModule,
|
||||
onnx_model: OnnxModel,
|
||||
):
|
||||
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
|
||||
for i in range(10):
|
||||
N = torch.randint(1, 100, size=(1,)).item()
|
||||
logging.info(f"test_joiner: iter {i}, N={N}")
|
||||
encoder_out = torch.rand(N, encoder_dim)
|
||||
decoder_out = torch.rand(N, decoder_dim)
|
||||
|
||||
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
|
||||
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
|
||||
|
||||
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
|
||||
onnx_joiner_out = onnx_model.run_joiner(
|
||||
projected_encoder_out, projected_decoder_out
|
||||
)
|
||||
|
||||
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
|
||||
(torch_joiner_out - onnx_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
|
||||
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
|
||||
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
|
||||
|
||||
onnx_model = OnnxModel(
|
||||
encoder_model_filename=args.onnx_encoder_filename,
|
||||
decoder_model_filename=args.onnx_decoder_filename,
|
||||
joiner_model_filename=args.onnx_joiner_filename,
|
||||
)
|
||||
|
||||
logging.info("Test encoder")
|
||||
# When exporting the model to onnx, we have already put the encoder_proj
|
||||
# inside the encoder.
|
||||
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test decoder")
|
||||
# When exporting the model to onnx, we have already put the decoder_proj
|
||||
# inside the decoder.
|
||||
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
|
||||
|
||||
logging.info("Test joiner")
|
||||
test_joiner(torch_joiner_model, onnx_model)
|
||||
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/38342
|
||||
# and https://github.com/pytorch/pytorch/issues/33354
|
||||
#
|
||||
# If we don't do this, the delay increases whenever there is
|
||||
# a new request that changes the actual batch size.
|
||||
# If you use `py-spy dump --pid <server-pid> --native`, you will
|
||||
# see a lot of time is spent in re-compiling the torch script model.
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20230207)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
|
@ -1,231 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class OnnxStreamingEncoder(torch.nn.Module):
|
||||
"""This class warps the streaming Zipformer to reduce the number of
|
||||
state tensors for onnx.
|
||||
https://github.com/k2-fsa/icefall/pull/831
|
||||
"""
|
||||
|
||||
def __init__(self, encoder):
|
||||
"""
|
||||
Args:
|
||||
encoder: An instance of the Zipformer class
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = encoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
len_cache: torch.tensor,
|
||||
avg_cache: torch.tensor,
|
||||
attn_cache: torch.tensor,
|
||||
cnn_cache: torch.tensor,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
len_cache:
|
||||
The cached numbers of past frames.
|
||||
avg_cache:
|
||||
The cached average tensors.
|
||||
attn_cache:
|
||||
The cached key tensors of the first attention modules.
|
||||
The cached value tensors of the first attention modules.
|
||||
The cached value tensors of the second attention modules.
|
||||
cnn_cache:
|
||||
The cached left contexts of the first convolution modules.
|
||||
The cached left contexts of the second convolution modules.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing 6 tensors: (encoder_out, encoder_out_lens,
new_len_cache, new_avg_cache, new_attn_cache, new_cnn_cache).
|
||||
|
||||
"""
|
||||
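# Unpack the stacked cache tensors back into the flat list of per-stack
# states expected by self.model.streaming_forward().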
num_encoder_layers = []
|
||||
encoder_attention_dims = []
|
||||
states = []
|
||||
for i, encoder in enumerate(self.model.encoders):
|
||||
num_encoder_layers.append(encoder.num_layers)
|
||||
encoder_attention_dims.append(encoder.attention_dim)
|
||||
|
||||
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(len_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(avg_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
|
||||
left_context_len = attn_cache.shape[1]
|
||||
offset = 0
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[offset : offset + num_layer, : left_context_len // ds]
|
||||
)
|
||||
offset += num_layer
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
encoder_attention_dim = encoder_attention_dims[i]
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[
|
||||
offset : offset + num_layer,
|
||||
: left_context_len // ds,
|
||||
:,
|
||||
: encoder_attention_dim // 2,
|
||||
]
|
||||
)
|
||||
offset += num_layer
|
||||
for i, num_layer in enumerate(num_encoder_layers):
|
||||
ds = self.model.zipformer_downsampling_factors[i]
|
||||
states.append(
|
||||
attn_cache[
|
||||
offset : offset + num_layer,
|
||||
: left_context_len // ds,
|
||||
:,
|
||||
: encoder_attention_dim // 2,
|
||||
]
|
||||
)
|
||||
offset += num_layer
|
||||
|
||||
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
|
||||
offset = 0
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(cnn_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
for num_layer in num_encoder_layers:
|
||||
states.append(cnn_cache[offset : offset + num_layer])
|
||||
offset += num_layer
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
|
||||
0, 1
|
||||
) # [B,15]
|
||||
new_avg_cache = torch.cat(
|
||||
states[self.model.num_encoders : 2 * self.model.num_encoders]
|
||||
).transpose(
|
||||
0, 1
|
||||
) # [B,15,384]
|
||||
new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
|
||||
0, 1
|
||||
) # [B,2*15,384,cnn_kernel-1]
|
||||
assert len(set(encoder_attention_dims)) == 1
|
||||
pad_tensors = [
|
||||
torch.nn.functional.pad(
|
||||
tensor,
|
||||
(
|
||||
0,
|
||||
encoder_attention_dims[0] - tensor.shape[-1],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
left_context_len - tensor.shape[1],
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
for tensor in states[
|
||||
2 * self.model.num_encoders : 5 * self.model.num_encoders
|
||||
]
|
||||
]
|
||||
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
|
||||
|
||||
return (
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_len_cache,
|
||||
new_avg_cache,
|
||||
new_attn_cache,
|
||||
new_cnn_cache,
|
||||
)
|
||||
|
||||
|
||||
class TritonOnnxDecoder(torch.nn.Module):
|
||||
"""This class warps the Decoder in decoder.py
|
||||
to remove the scalar input "need_pad".
|
||||
Triton currently doesn't support scalar input.
|
||||
https://github.com/triton-inference-server/server/issues/2333
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: torch.nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
decoder: An instance of Decoder
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = decoder
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
# need_pad is False so the input is not padded; it must be False during inference.
|
||||
need_pad = False
|
||||
return self.model(y, need_pad)
|
||||
|
||||
|
||||
class TritonOnnxJoiner(torch.nn.Module):
|
||||
"""This class warps the Joiner in joiner.py
|
||||
to remove the scalar input "project_input".
|
||||
Triton currently doesn't support scalar input.
|
||||
https://github.com/triton-inference-server/server/issues/2333
|
||||
"project_input" is set to True.
|
||||
For Triton, the joiner only needs to be exported as a single joiner.onnx.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
joiner: torch.nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = joiner
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
# Apply input projections encoder_proj and decoder_proj.
|
||||
project_input = True
|
||||
return self.model(encoder_out, decoder_out, project_input)
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
|
@ -1,512 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files in $repo/exp
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "zipformer", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
|
||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||
encoder_dims = encoder_meta["encoder_dims"]
|
||||
attention_dims = encoder_meta["attention_dims"]
|
||||
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
|
||||
left_context_len = encoder_meta["left_context_len"]
|
||||
|
||||
def to_int_list(s):
|
||||
return list(map(int, s.split(",")))
|
||||
|
||||
num_encoder_layers = to_int_list(num_encoder_layers)
|
||||
encoder_dims = to_int_list(encoder_dims)
|
||||
attention_dims = to_int_list(attention_dims)
|
||||
cnn_module_kernels = to_int_list(cnn_module_kernels)
|
||||
left_context_len = to_int_list(left_context_len)
|
||||
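# For example, a metadata string such as "2,4,3,2,4" becomes [2, 4, 3, 2, 4];
# the released streaming zipformer models use 5 encoder stacks.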
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
logging.info(f"encoder_dims: {encoder_dims}")
|
||||
logging.info(f"attention_dims: {attention_dims}")
|
||||
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
|
||||
logging.info(f"left_context_len: {left_context_len}")
|
||||
|
||||
num_encoders = len(num_encoder_layers)
|
||||
|
||||
cached_len = []
|
||||
cached_avg = []
|
||||
cached_key = []
|
||||
cached_val = []
|
||||
cached_val2 = []
|
||||
cached_conv1 = []
|
||||
cached_conv2 = []
|
||||
|
||||
N = batch_size
|
||||
|
||||
for i in range(num_encoders):
|
||||
cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64))
|
||||
cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i]))
|
||||
cached_key.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], left_context_len[i], N, attention_dims[i]
|
||||
)
|
||||
)
|
||||
cached_val.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_val2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i],
|
||||
left_context_len[i],
|
||||
N,
|
||||
attention_dims[i] // 2,
|
||||
)
|
||||
)
|
||||
cached_conv1.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
cached_conv2.append(
|
||||
torch.zeros(
|
||||
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
|
||||
)
|
||||
)
|
||||
|
||||
self.cached_len = cached_len
|
||||
self.cached_avg = cached_avg
|
||||
self.cached_key = cached_key
|
||||
self.cached_val = cached_val
|
||||
self.cached_val2 = cached_val2
|
||||
self.cached_conv1 = cached_conv1
|
||||
self.cached_conv2 = cached_conv2
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
|
||||
self.segment = T
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
encoder_input = {"x": x.numpy()}
|
||||
encoder_output = ["encoder_out"]
|
||||
|
||||
def build_states_input(states: List[torch.Tensor], name: str):
|
||||
for i, s in enumerate(states):
|
||||
if isinstance(s, torch.Tensor):
|
||||
encoder_input[f"{name}_{i}"] = s.numpy()
|
||||
else:
|
||||
encoder_input[f"{name}_{i}"] = s
|
||||
|
||||
encoder_output.append(f"new_{name}_{i}")
|
||||
|
||||
build_states_input(self.cached_len, "cached_len")
|
||||
build_states_input(self.cached_avg, "cached_avg")
|
||||
build_states_input(self.cached_key, "cached_key")
|
||||
build_states_input(self.cached_val, "cached_val")
|
||||
build_states_input(self.cached_val2, "cached_val2")
|
||||
build_states_input(self.cached_conv1, "cached_conv1")
|
||||
build_states_input(self.cached_conv2, "cached_conv2")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
num_encoders = self.num_encoders
|
||||
|
||||
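# The encoder returns the updated caches in the same order they were passed
# in: len, avg, key, val, val2, conv1, conv2, with num_encoders tensors in
# each group.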
self.cached_len = states[num_encoders * 0 : num_encoders * 1]
|
||||
self.cached_avg = states[num_encoders * 1 : num_encoders * 2]
|
||||
self.cached_key = states[num_encoders * 2 : num_encoders * 3]
|
||||
self.cached_val = states[num_encoders * 3 : num_encoders * 4]
|
||||
self.cached_val2 = states[num_encoders * 4 : num_encoders * 5]
|
||||
self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6]
|
||||
self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7]
|
||||
|
||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
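# Decode the chunk frame by frame, emitting at most one non-blank symbol per
# frame.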
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
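# segment is the number of feature frames consumed per encoder call and
# offset (decode_chunk_len) is how far the feature pointer advances, so
# consecutive segments overlap by segment - offset frames.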
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
|
@ -1,355 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7_streaming/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
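# The trivial graph places no constraint on the output; it only provides an
# FSA for fast_beam_search to intersect with the decoding lattice.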
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
|
@ -1,419 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
|
||||
--tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \
|
||||
--encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \
|
||||
--encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \
|
||||
--decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \
|
||||
--decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \
|
||||
--joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \
|
||||
--joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \
|
||||
./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav
|
||||
|
||||
You can find pretrained models at
|
||||
- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
|
||||
- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import ncnn
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-param-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-param-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-param-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-bin-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename",
|
||||
type=str,
|
||||
help="Path to foo.wav",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def to_int_tuple(s: str):
|
||||
return tuple(map(int, s.split(",")))
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, args):
|
||||
self.init_encoder(args)
|
||||
self.init_decoder(args)
|
||||
self.init_joiner(args)
|
||||
|
||||
# Please change the parameters according to your model
|
||||
self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
|
||||
self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model
|
||||
self.attention_dims = to_int_tuple("192,192,192,192,192")
|
||||
self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
|
||||
self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
|
||||
|
||||
self.decode_chunk_size = 32 // 2
|
||||
num_left_chunks = 4
|
||||
self.left_context_length = self.decode_chunk_size * num_left_chunks # 64
|
||||
|
||||
self.chunk_length = self.decode_chunk_size * 2
|
||||
pad_length = 7
|
||||
self.T = self.chunk_length + pad_length
|
||||
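# With these defaults: decode_chunk_size = 16, left_context_length = 64,
# chunk_length = 32 and T = 39 feature frames per chunk.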
|
||||
def get_init_states(self) -> List[torch.Tensor]:
|
||||
cached_len_list = []
|
||||
cached_avg_list = []
|
||||
cached_key_list = []
|
||||
cached_val_list = []
|
||||
cached_val2_list = []
|
||||
cached_conv1_list = []
|
||||
cached_conv2_list = []
|
||||
|
||||
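# One cache tensor of each kind per encoder stack; the lists are concatenated
# below in the order expected by the exported ncnn encoder.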
for i in range(len(self.num_encoder_layers)):
|
||||
num_layers = self.num_encoder_layers[i]
|
||||
ds = self.zipformer_downsampling_factors[i]
|
||||
attention_dim = self.attention_dims[i]
|
||||
left_context_length = self.left_context_length // ds
|
||||
encoder_dim = self.encoder_dims[i]
|
||||
cnn_module_kernel = self.cnn_module_kernels[i]
|
||||
|
||||
cached_len_list.append(torch.zeros(num_layers))
|
||||
cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
|
||||
cached_key_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim)
|
||||
)
|
||||
cached_val_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||
)
|
||||
cached_val2_list.append(
|
||||
torch.zeros(num_layers, left_context_length, attention_dim // 2)
|
||||
)
|
||||
cached_conv1_list.append(
|
||||
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||
)
|
||||
cached_conv2_list.append(
|
||||
torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
|
||||
)
|
||||
|
||||
states = (
|
||||
cached_len_list
|
||||
+ cached_avg_list
|
||||
+ cached_key_list
|
||||
+ cached_val_list
|
||||
+ cached_val2_list
|
||||
+ cached_conv1_list
|
||||
+ cached_conv2_list
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
def init_encoder(self, args):
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_net.opt.num_threads = 4
|
||||
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
encoder_net.load_param(encoder_param)
|
||||
encoder_net.load_model(encoder_model)
|
||||
|
||||
self.encoder_net = encoder_net
|
||||
|
||||
def init_decoder(self, args):
|
||||
decoder_param = args.decoder_param_filename
|
||||
decoder_model = args.decoder_bin_filename
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
decoder_net.opt.num_threads = 4
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
|
||||
self.decoder_net = decoder_net
|
||||
|
||||
def init_joiner(self, args):
|
||||
joiner_param = args.joiner_param_filename
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.opt.num_threads = 4
|
||||
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
self.joiner_net = joiner_net
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
states: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, C)
|
||||
states:
|
||||
A list of tensors containing the encoder streaming states;
len(states) == 7 * len(self.num_encoder_layers)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, a tensor of shape (T, encoder_dim).
|
||||
- next_states, a list of tensors containing the next states
|
||||
"""
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
|
||||
for i in range(len(states)):
|
||||
name = f"in{i+1}"
|
||||
ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone())
|
||||
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
|
||||
out_states: List[torch.Tensor] = []
|
||||
for i in range(len(states)):
|
||||
name = f"out{i+1}"
|
||||
ret, ncnn_out_state = ex.extract(name)
|
||||
assert ret == 0, ret
|
||||
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
|
||||
|
||||
if i < len(self.num_encoder_layers):
|
||||
# for cached_len, we need to discard the last dim
|
||||
ncnn_out_state = ncnn_out_state.squeeze(1)
|
||||
|
||||
out_states.append(ncnn_out_state)
|
||||
|
||||
return encoder_out, out_states
|
||||
|
||||
def run_decoder(self, decoder_input):
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return decoder_out
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return joiner_out
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Model,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
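# A short description of the loop below: this is frame-synchronous greedy
# (argmax) search for a transducer. For each encoder frame we run the joiner
# and take the most probable token; whenever a non-blank token is emitted,
# the last `context_size` tokens of the hypothesis are fed back through the
# decoder network to refresh decoder_out before processing the next frame.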
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32)  # (context_size,)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
else:
|
||||
assert decoder_out.ndim == 1
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t]
|
||||
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = Model(args)
|
||||
|
||||
sound_file = args.sound_filename
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
states = model.get_init_states()
|
||||
logging.info(f"number of states: {len(states)}")
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.T
|
||||
offset = model.chunk_length
|
||||
|
||||
chunk = int(1 * sample_rate)  # feed 1 second of audio at a time
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
encoder_out, states = model.run_encoder(frames, states)
|
||||
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
context_size = 2
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(sound_file)
|
||||
logging.info(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
|
File diff suppressed because it is too large
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
File diff suppressed because it is too large
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
|
@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script takes as input an FST in k2 format and converts it
|
||||
to an FST in OpenFST format.
|
||||
|
||||
The generated FST is saved into a binary file and its type is
|
||||
StdVectorFst.
|
||||
|
||||
Usage examples:
|
||||
(1) Convert an acceptor
|
||||
|
||||
./convert-k2-to-openfst.py in.pt binary.fst
|
||||
|
||||
(2) Convert a transducer
|
||||
|
||||
./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import kaldifst.utils
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--olabels",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""If not empty, the input FST is assumed to be a transducer
|
||||
and we use its attribute specified by "olabels" as the output labels.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"input_filename",
|
||||
type=str,
|
||||
help="Path to the input FST in k2 format",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"output_filename",
|
||||
type=str,
|
||||
help="Path to the output FST in OpenFst format",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(f"{vars(args)}")
|
||||
|
||||
input_filename = args.input_filename
|
||||
output_filename = args.output_filename
|
||||
olabels = args.olabels
|
||||
|
||||
if Path(output_filename).is_file():
|
||||
logging.info(f"{output_filename} already exists - skipping")
|
||||
return
|
||||
|
||||
assert Path(input_filename).is_file(), f"{input_filename} does not exist"
|
||||
logging.info(f"Loading {input_filename}")
|
||||
k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
|
||||
if olabels:
|
||||
assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
|
||||
|
||||
p = Path(output_filename).parent
|
||||
if not p.is_dir():
|
||||
logging.info(f"Creating {p}")
|
||||
p.mkdir(parents=True)
|
||||
|
||||
logging.info("Converting (May take some time if the input FST is large)")
|
||||
fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
|
||||
logging.info(f"Saving to {output_filename}")
|
||||
fst.write(output_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
icefall/shared/convert-k2-to-openfst.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/convert-k2-to-openfst.py
|
@ -1,443 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2016 Johns Hopkins University (Author: Daniel Povey)
|
||||
# 2018 Ruizhe Huang
|
||||
# Apache 2.0.
|
||||
|
||||
# This is an implementation of computing Kneser-Ney smoothed language model
|
||||
# in the same way as srilm. This is a back-off, unmodified version of
|
||||
# Kneser-Ney smoothing, which produces the same results as the following
|
||||
# command (as an example) of srilm:
|
||||
#
|
||||
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
|
||||
# -text corpus.txt -lm lm.arpa
|
||||
#
|
||||
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
|
||||
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
Generate a Kneser-Ney language model in ARPA format. By default,
|
||||
it will read the corpus from standard input, and output to standard output.
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ngram-order",
|
||||
type=int,
|
||||
default=4,
|
||||
choices=[2, 3, 4, 5, 6, 7],
|
||||
help="Order of n-gram",
|
||||
)
|
||||
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
|
||||
parser.add_argument(
|
||||
"-lm", type=str, default=None, help="Path to output arpa file for language models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# For encoding-agnostic scripts, we assume byte stream as input.
|
||||
# Need to be very careful about the use of strip() and split()
|
||||
# in this case, because there is a latin-1 whitespace character
|
||||
# (nbsp) which is part of the unicode encoding range.
|
||||
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
|
||||
default_encoding = "latin-1"
|
||||
|
||||
strip_chars = " \t\r\n"
|
||||
whitespace = re.compile("[ \t]+")
|
||||
|
||||
|
||||
class CountsForHistory:
|
||||
# This class (which is more like a struct) stores the counts seen in a
|
||||
# particular history-state. It is used inside class NgramCounts.
|
||||
# It really does the job of a dict from int to float, but it also
|
||||
# keeps track of the total count.
|
||||
def __init__(self):
|
||||
# word_to_count maps each predicted word to its (integer) raw count
# within this history state.
|
||||
self.word_to_count = defaultdict(int)
|
||||
# using a set to count the number of unique contexts
|
||||
self.word_to_context = defaultdict(set)
|
||||
self.word_to_f = dict() # discounted probability
|
||||
self.word_to_bow = dict() # back-off weight
|
||||
self.total_count = 0
|
||||
|
||||
def words(self):
|
||||
return self.word_to_count.keys()
|
||||
|
||||
def __str__(self):
|
||||
# e.g. returns ' total=12: 3->4, 4->6, -1->2'
|
||||
return " total={0}: {1}".format(
|
||||
str(self.total_count),
|
||||
", ".join(
|
||||
[
|
||||
"{0} -> {1}".format(word, count)
|
||||
for word, count in self.word_to_count.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def add_count(self, predicted_word, context_word, count):
|
||||
assert count >= 0
|
||||
|
||||
self.total_count += count
|
||||
self.word_to_count[predicted_word] += count
|
||||
if context_word is not None:
|
||||
self.word_to_context[predicted_word].add(context_word)
|
||||
|
||||
|
||||
class NgramCounts:
|
||||
# A note on data-structure. Firstly, all words are represented as
|
||||
# integers. We store n-gram counts as an array, indexed by (history-length
|
||||
# == n-gram order minus one) (note: python calls arrays "lists") of dicts
|
||||
# from histories to counts, where histories are arrays of integers and
|
||||
# "counts" are dicts from integer to float. For instance, when
|
||||
# accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
|
||||
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
|
||||
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
|
||||
def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
|
||||
assert ngram_order >= 2
|
||||
|
||||
self.ngram_order = ngram_order
|
||||
self.bos_symbol = bos_symbol
|
||||
self.eos_symbol = eos_symbol
|
||||
|
||||
self.counts = []
|
||||
for n in range(ngram_order):
|
||||
self.counts.append(defaultdict(lambda: CountsForHistory()))
|
||||
|
||||
self.d = [] # list of discounting factor for each order of ngram
|
||||
|
||||
# adds a raw count (called while processing input data).
|
||||
# Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
|
||||
# would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
|
||||
# 1.
|
||||
def add_count(self, history, predicted_word, context_word, count):
|
||||
self.counts[len(history)][history].add_count(
|
||||
predicted_word, context_word, count
|
||||
)
|
||||
|
||||
# 'line' is a string containing a sequence of integer word-ids.
|
||||
# This function adds the un-smoothed counts from this line of text.
|
||||
def add_raw_counts_from_line(self, line):
|
||||
if line == "":
|
||||
words = [self.bos_symbol, self.eos_symbol]
|
||||
else:
|
||||
words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
|
||||
|
||||
for i in range(len(words)):
|
||||
for n in range(1, self.ngram_order + 1):
|
||||
if i + n > len(words):
|
||||
break
|
||||
ngram = words[i : i + n]
|
||||
predicted_word = ngram[-1]
|
||||
history = tuple(ngram[:-1])
|
||||
if i == 0 or n == self.ngram_order:
|
||||
context_word = None
|
||||
else:
|
||||
context_word = words[i - 1]
|
||||
|
||||
self.add_count(history, predicted_word, context_word, 1)
|
||||
|
||||
def add_raw_counts_from_standard_input(self):
|
||||
lines_processed = 0
|
||||
# byte stream as input
|
||||
infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
|
||||
for line in infile:
|
||||
line = line.strip(strip_chars)
|
||||
self.add_raw_counts_from_line(line)
|
||||
lines_processed += 1
|
||||
if lines_processed == 0 or args.verbose > 0:
|
||||
print(
|
||||
"make_phone_lm.py: processed {0} lines of input".format(
|
||||
lines_processed
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def add_raw_counts_from_file(self, filename):
|
||||
lines_processed = 0
|
||||
with open(filename, encoding=default_encoding) as fp:
|
||||
for line in fp:
|
||||
line = line.strip(strip_chars)
|
||||
self.add_raw_counts_from_line(line)
|
||||
lines_processed += 1
|
||||
if lines_processed == 0 or args.verbose > 0:
|
||||
print(
|
||||
"make_phone_lm.py: processed {0} lines of input".format(
|
||||
lines_processed
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def cal_discounting_constants(self):
|
||||
# For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
|
||||
# where n1_N is the number of unique N-grams with count = 1 (counts-of-counts).
|
||||
# This constant is used similarly to absolute discounting.
|
||||
# Return value: d is a list of floats, where d[N+1] = D_N
|
||||
|
||||
# for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
|
||||
# This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
|
||||
# but perhaps this is not the case for some other scenarios.
|
||||
self.d = [0]
|
||||
for n in range(1, self.ngram_order):
|
||||
this_order_counts = self.counts[n]
|
||||
n1 = 0
|
||||
n2 = 0
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
stat = Counter(counts_for_hist.word_to_count.values())
|
||||
n1 += stat[1]
|
||||
n2 += stat[2]
|
||||
assert n1 + 2 * n2 > 0
|
||||
|
||||
# We use max(0.1, n1) to avoid a zero discounting constant D when n1 = 0,
# which could happen if the number of symbols is small.
# Otherwise, a zero discounting constant can cause division by zero when computing BOW.
self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))
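# Illustrative example (numbers made up for this sketch): with n1 = 30
# one-count n-grams and n2 = 10 two-count n-grams,
#     D = n1 / (n1 + 2 * n2) = 30 / 50 = 0.6;
# the max(0.1, n1) guard above only changes the result when n1 == 0.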
|
||||
|
||||
def cal_f(self):
|
||||
# f(a_z) is a probability distribution of word sequence a_z.
|
||||
# Typically f(a_z) is discounted to be less than the ML estimate so we have
|
||||
# some leftover probability for the z words unseen in the context (a_).
|
||||
#
|
||||
# f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams
|
||||
# f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams
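# Small numeric sketch of the highest-order case (values chosen for
# illustration): if the 4-gram "a b c d" was seen 3 times, its history
# "a b c" was seen 10 times in total, and the discount D is 0.6, then
#     f(d | a b c) = (3 - 0.6) / 10 = 0.24.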
|
||||
|
||||
# highest order N-grams
|
||||
n = self.ngram_order - 1
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w, c in counts_for_hist.word_to_count.items():
|
||||
counts_for_hist.word_to_f[w] = (
|
||||
max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
|
||||
)
|
||||
|
||||
# lower order N-grams
|
||||
for n in range(0, self.ngram_order - 1):
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
|
||||
n_star_star = 0
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_star += len(counts_for_hist.word_to_context[w])
|
||||
|
||||
if n_star_star != 0:
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_z = len(counts_for_hist.word_to_context[w])
|
||||
counts_for_hist.word_to_f[w] = (
|
||||
max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
|
||||
)
|
||||
else:  # patterns beginning with <s> have no "modified count", so use the raw count instead
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_z = counts_for_hist.word_to_count[w]
|
||||
counts_for_hist.word_to_f[w] = (
|
||||
max((n_star_z - self.d[n]), 0)
|
||||
* 1.0
|
||||
/ counts_for_hist.total_count
|
||||
)
|
||||
|
||||
def cal_bow(self):
|
||||
# Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
|
||||
# Thus, two sorts of ngrams do not have a bow:
|
||||
# 1) highest order ngram
|
||||
# 2) ngrams ending in </s>
|
||||
#
|
||||
# bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
|
||||
# Note that Z1 is the set of all words with c(a_z) > 0
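# Numeric sketch (values chosen for illustration): if the discounted
# probabilities f(a_z) of the words seen after context "a_" sum to 0.8,
# and the same words under the shorter context "_" have f(_z) summing to 0.5,
# then bow(a_) = (1 - 0.8) / (1 - 0.5) = 0.4.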
|
||||
|
||||
# highest order N-grams
|
||||
n = self.ngram_order - 1
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
counts_for_hist.word_to_bow[w] = None
|
||||
|
||||
# lower order N-grams
|
||||
for n in range(0, self.ngram_order - 1):
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
if w == self.eos_symbol:
|
||||
counts_for_hist.word_to_bow[w] = None
|
||||
else:
|
||||
a_ = hist + (w,)
|
||||
|
||||
assert len(a_) < self.ngram_order
|
||||
assert a_ in self.counts[len(a_)].keys()
|
||||
|
||||
a_counts_for_hist = self.counts[len(a_)][a_]
|
||||
|
||||
sum_z1_f_a_z = 0
|
||||
for u in a_counts_for_hist.word_to_count.keys():
|
||||
sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]
|
||||
|
||||
sum_z1_f_z = 0
|
||||
_ = a_[1:]
|
||||
_counts_for_hist = self.counts[len(_)][_]
|
||||
# Should be careful here: what is Z1
|
||||
for u in a_counts_for_hist.word_to_count.keys():
|
||||
sum_z1_f_z += _counts_for_hist.word_to_f[u]
|
||||
|
||||
if sum_z1_f_z < 1:
|
||||
# assert sum_z1_f_a_z < 1
|
||||
counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
|
||||
1.0 - sum_z1_f_z
|
||||
)
|
||||
else:
|
||||
counts_for_hist.word_to_bow[w] = None
|
||||
|
||||
def print_raw_counts(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
res.append(
|
||||
"{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
|
||||
)
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_modified_counts(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
modified_count = len(counts_for_hist.word_to_context[w])
|
||||
raw_count = counts_for_hist.word_to_count[w]
|
||||
|
||||
if modified_count == 0:
|
||||
res.append("{0}\t{1}".format(ngram, raw_count))
|
||||
else:
|
||||
res.append("{0}\t{1}".format(ngram, modified_count))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_f(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
f = counts_for_hist.word_to_f[w]
|
||||
if f == 0: # f(<s>) is always 0
|
||||
f = 1e-99
|
||||
|
||||
res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_f_and_bow(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
f = counts_for_hist.word_to_f[w]
|
||||
if f == 0: # f(<s>) is always 0
|
||||
f = 1e-99
|
||||
|
||||
bow = counts_for_hist.word_to_bow[w]
|
||||
if bow is None:
|
||||
res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
|
||||
else:
|
||||
res.append(
|
||||
"{1}\t{0}\t{2}".format(
|
||||
ngram, math.log(f, 10), math.log(bow, 10)
|
||||
)
|
||||
)
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_as_arpa(
|
||||
self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
|
||||
):
|
||||
# print as ARPA format.
|
||||
|
||||
print("\\data\\", file=fout)
|
||||
for hist_len in range(self.ngram_order):
|
||||
# print the number of n-grams.
|
||||
print(
|
||||
"ngram {0}={1}".format(
|
||||
hist_len + 1,
|
||||
sum(
|
||||
[
|
||||
len(counts_for_hist.word_to_f)
|
||||
for counts_for_hist in self.counts[hist_len].values()
|
||||
]
|
||||
),
|
||||
),
|
||||
file=fout,
|
||||
)
|
||||
|
||||
print("", file=fout)
|
||||
|
||||
for hist_len in range(self.ngram_order):
|
||||
print("\\{0}-grams:".format(hist_len + 1), file=fout)
|
||||
|
||||
this_order_counts = self.counts[hist_len]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for word in counts_for_hist.word_to_count.keys():
|
||||
ngram = hist + (word,)
|
||||
prob = counts_for_hist.word_to_f[word]
|
||||
bow = counts_for_hist.word_to_bow[word]
|
||||
|
||||
if prob == 0: # f(<s>) is always 0
|
||||
prob = 1e-99
|
||||
|
||||
line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
|
||||
if bow is not None:
|
||||
line += "\t{0}".format("%.7f" % math.log10(bow))
|
||||
print(line, file=fout)
|
||||
print("", file=fout)
|
||||
print("\\end\\", file=fout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
ngram_counts = NgramCounts(args.ngram_order)
|
||||
if args.text is None:
|
||||
ngram_counts.add_raw_counts_from_standard_input()
|
||||
else:
|
||||
assert os.path.isfile(args.text)
|
||||
ngram_counts.add_raw_counts_from_file(args.text)
|
||||
|
||||
ngram_counts.cal_discounting_constants()
|
||||
ngram_counts.cal_f()
|
||||
ngram_counts.cal_bow()
|
||||
|
||||
if args.lm is None:
|
||||
ngram_counts.print_as_arpa()
|
||||
else:
|
||||
with open(args.lm, "w", encoding=default_encoding) as f:
|
||||
ngram_counts.print_as_arpa(fout=f)
|
icefall/shared/make_kn_lm.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/make_kn_lm.py
|
@ -1,630 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
./ngram_entropy_pruning.py \
|
||||
-threshold 1e-8 \
|
||||
-lm download/lm/4gram.arpa \
|
||||
-write-lm download/lm/4gram_pruned_1e8.arpa
|
||||
|
||||
This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
|
||||
This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
|
||||
in the same way as SRILM.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import gzip
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum, unique
|
||||
from io import StringIO
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
Prune an n-gram language model based on the relative entropy
|
||||
between the original and the pruned model, based on Andreas Stolcke's paper.
|
||||
An n-gram entry is removed if removing it increases the (training-set)
perplexity of the model by less than the given relative threshold.
|
||||
|
||||
The command takes an arpa file and a pruning threshold as input,
|
||||
and outputs a pruned arpa file.
|
||||
"""
|
||||
)
|
||||
parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
|
||||
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
|
||||
parser.add_argument(
|
||||
"-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-minorder",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The minorder parameter limits pruning to ngrams of that length and above.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-verbose",
|
||||
type=int,
|
||||
default=2,
|
||||
choices=[0, 1, 2, 3, 4, 5],
|
||||
help="Verbose level, where 0 is most noisy; 5 is most silent",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
default_encoding = args.encoding
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
|
||||
level=args.verbose * 10,
|
||||
)
|
||||
|
||||
|
||||
class Context(dict):
|
||||
"""
|
||||
This class stores data for a context h.
|
||||
It behaves like a python dict object, except that it has several
|
||||
additional attributes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.log_bo = None
|
||||
|
||||
|
||||
class Arpa:
|
||||
"""
|
||||
This is a class that implements the data structure of an ARPA LM.
|
||||
It (as well as some other classes) is modified based on the library
|
||||
by Stefan Fischer:
|
||||
https://github.com/sfischer13/python-arpa
|
||||
"""
|
||||
|
||||
UNK = "<unk>"
|
||||
SOS = "<s>"
|
||||
EOS = "</s>"
|
||||
FLOAT_NDIGITS = 7
|
||||
base = 10
|
||||
|
||||
@staticmethod
|
||||
def _check_input(my_input):
|
||||
if not my_input:
|
||||
raise ValueError
|
||||
elif isinstance(my_input, tuple):
|
||||
return my_input
|
||||
elif isinstance(my_input, list):
|
||||
return tuple(my_input)
|
||||
elif isinstance(my_input, str):
|
||||
return tuple(my_input.strip().split(" "))
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
@staticmethod
|
||||
def _check_word(input_word):
|
||||
if not isinstance(input_word, str):
|
||||
raise ValueError
|
||||
if " " in input_word:
|
||||
raise ValueError
|
||||
|
||||
def _replace_unks(self, words):
|
||||
return tuple((w if w in self else self._unk) for w in words)
|
||||
|
||||
def __init__(self, path=None, encoding=None, unk=None):
|
||||
self._counts = OrderedDict()
|
||||
self._ngrams = (
|
||||
OrderedDict()
|
||||
) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
|
||||
self._vocabulary = set()
|
||||
if unk is None:
|
||||
self._unk = self.UNK
|
||||
|
||||
if path is not None:
|
||||
self.loadf(path, encoding)
|
||||
|
||||
def __contains__(self, ngram):
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
|
||||
|
||||
def contains_word(self, word):
|
||||
self._check_word(word)
|
||||
return word in self._vocabulary
|
||||
|
||||
def add_count(self, order, count):
|
||||
self._counts[order] = count
|
||||
self._ngrams[order - 1] = defaultdict(Context)
|
||||
|
||||
def update_counts(self):
|
||||
for order in range(1, self.order() + 1):
|
||||
count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
|
||||
if count > 0:
|
||||
self._counts[order] = count
|
||||
|
||||
def add_entry(self, ngram, p, bo=None, order=None):
|
||||
# Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
|
||||
# Note that p and bo here are in fact in the log domain (self.base = 10)
|
||||
h_context = self._ngrams[len(h)][h]
|
||||
h_context[w] = p
|
||||
if bo is not None:
|
||||
self._ngrams[len(ngram)][ngram].log_bo = bo
|
||||
|
||||
for word in ngram:
|
||||
self._vocabulary.add(word)
|
||||
|
||||
def counts(self):
|
||||
return sorted(self._counts.items())
|
||||
|
||||
def order(self):
|
||||
return max(self._counts.keys(), default=None)
|
||||
|
||||
def vocabulary(self, sort=True):
|
||||
if sort:
|
||||
return sorted(self._vocabulary)
|
||||
else:
|
||||
return self._vocabulary
|
||||
|
||||
def _entries(self, order):
|
||||
return (
|
||||
self._entry(h, w)
|
||||
for h, wlist in self._ngrams[order - 1].items()
|
||||
for w in wlist
|
||||
)
|
||||
|
||||
def _entry(self, h, w):
|
||||
# return the entry for the ngram (h, w)
|
||||
ngram = h + (w,)
|
||||
log_p = self._ngrams[len(h)][h][w]
|
||||
log_bo = self._log_bo(ngram)
|
||||
if log_bo is not None:
|
||||
return (
|
||||
round(log_p, self.FLOAT_NDIGITS),
|
||||
ngram,
|
||||
round(log_bo, self.FLOAT_NDIGITS),
|
||||
)
|
||||
else:
|
||||
return round(log_p, self.FLOAT_NDIGITS), ngram
|
||||
|
||||
def _log_bo(self, ngram):
|
||||
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
|
||||
return self._ngrams[len(ngram)][ngram].log_bo
|
||||
else:
|
||||
return None
|
||||
|
||||
def _log_p(self, ngram):
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
|
||||
return self._ngrams[len(h)][h][w]
|
||||
else:
|
||||
return None
|
||||
|
||||
def log_p_raw(self, ngram):
|
||||
log_p = self._log_p(ngram)
|
||||
if log_p is not None:
|
||||
return log_p
|
||||
else:
|
||||
if len(ngram) == 1:
|
||||
raise KeyError
|
||||
else:
|
||||
log_bo = self._log_bo(ngram[:-1])
|
||||
if log_bo is None:
|
||||
log_bo = 0
|
||||
return log_bo + self.log_p_raw(ngram[1:])
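# This is the standard back-off rule, written out:
#   log p(w | h) = log BOW(h) + log p(w | h'),
# where h' drops the oldest word of h, and log BOW(h) is taken as 0
# (i.e. BOW(h) = 1) when no back-off weight is stored for the context h.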
|
||||
|
||||
def log_joint_prob(self, sequence):
|
||||
# Compute the joint prob of the sequence based on the chain rule
|
||||
# Note that sequence should be a tuple of strings
|
||||
#
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
|
||||
|
||||
log_joint_p = 0
|
||||
seq = sequence
|
||||
while len(seq) > 0:
|
||||
log_joint_p += self.log_p_raw(seq)
|
||||
seq = seq[:-1]
|
||||
|
||||
# If we're computing the marginal probability of the unigram
|
||||
# <s> context we have to look up </s> instead since the former
|
||||
# has prob = 0.
|
||||
if len(seq) == 1 and seq[0] == self.SOS:
|
||||
seq = (self.EOS,)
|
||||
|
||||
return log_joint_p
|
||||
|
||||
def set_new_context(self, h):
|
||||
old_context = self._ngrams[len(h)][h]
|
||||
self._ngrams[len(h)][h] = Context()
|
||||
return old_context
|
||||
|
||||
def log_p(self, ngram):
|
||||
words = self._check_input(ngram)
|
||||
if self._unk:
|
||||
words = self._replace_unks(words)
|
||||
return self.log_p_raw(words)
|
||||
|
||||
def log_s(self, sentence, sos=SOS, eos=EOS):
|
||||
words = self._check_input(sentence)
|
||||
if self._unk:
|
||||
words = self._replace_unks(words)
|
||||
if sos:
|
||||
words = (sos,) + words
|
||||
if eos:
|
||||
words = words + (eos,)
|
||||
result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
|
||||
if sos:
|
||||
result = result - self.log_p_raw(words[:1])
|
||||
return result
|
||||
|
||||
def p(self, ngram):
|
||||
return self.base ** self.log_p(ngram)
|
||||
|
||||
def s(self, sentence):
|
||||
return self.base ** self.log_s(sentence)
|
||||
|
||||
def write(self, fp):
|
||||
fp.write("\n\\data\\\n")
|
||||
for order, count in self.counts():
|
||||
fp.write("ngram {}={}\n".format(order, count))
|
||||
fp.write("\n")
|
||||
for order, _ in self.counts():
|
||||
fp.write("\\{}-grams:\n".format(order))
|
||||
for e in self._entries(order):
|
||||
prob = e[0]
|
||||
ngram = " ".join(e[1])
|
||||
if len(e) == 2:
|
||||
fp.write("{}\t{}\n".format(prob, ngram))
|
||||
elif len(e) == 3:
|
||||
backoff = e[2]
|
||||
fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
|
||||
else:
|
||||
raise ValueError
|
||||
fp.write("\n")
|
||||
fp.write("\\end\\\n")
|
||||
|
||||
|
||||
class ArpaParser:
|
||||
"""
|
||||
This is a class that implements a parser for an ARPA file
|
||||
"""
|
||||
|
||||
@unique
|
||||
class State(Enum):
|
||||
DATA = 1
|
||||
COUNT = 2
|
||||
HEADER = 3
|
||||
ENTRY = 4
|
||||
|
||||
re_count = re.compile(r"^ngram (\d+)=(\d+)$")
|
||||
re_header = re.compile(r"^\\(\d+)-grams:$")
|
||||
re_entry = re.compile(
|
||||
"^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
|
||||
"\t"
|
||||
"(\\S+( \\S+)*)"
|
||||
"(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
|
||||
)
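# Example of an ARPA entry line matched by re_entry (a made-up line for
# illustration): "-1.2041\tthe quick fox\t-0.3010".
# group(1) is the log10 probability, group(4) is the n-gram itself, and
# group(7) is the optional back-off weight; these are the groups read below.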
|
||||
|
||||
def _parse(self, fp):
|
||||
self._result = []
|
||||
self._state = self.State.DATA
|
||||
self._tmp_model = None
|
||||
self._tmp_order = None
|
||||
for line in fp:
|
||||
line = line.strip()
|
||||
if self._state == self.State.DATA:
|
||||
self._data(line)
|
||||
elif self._state == self.State.COUNT:
|
||||
self._count(line)
|
||||
elif self._state == self.State.HEADER:
|
||||
self._header(line)
|
||||
elif self._state == self.State.ENTRY:
|
||||
self._entry(line)
|
||||
if self._state != self.State.DATA:
|
||||
raise Exception(line)
|
||||
return self._result
|
||||
|
||||
def _data(self, line):
|
||||
if line == "\\data\\":
|
||||
self._state = self.State.COUNT
|
||||
self._tmp_model = Arpa()
|
||||
else:
|
||||
pass # skip comment line
|
||||
|
||||
def _count(self, line):
|
||||
match = self.re_count.match(line)
|
||||
if match:
|
||||
order = match.group(1)
|
||||
count = match.group(2)
|
||||
self._tmp_model.add_count(int(order), int(count))
|
||||
elif not line:
|
||||
self._state = self.State.HEADER # there are no counts
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
def _header(self, line):
|
||||
match = self.re_header.match(line)
|
||||
if match:
|
||||
self._state = self.State.ENTRY
|
||||
self._tmp_order = int(match.group(1))
|
||||
elif line == "\\end\\":
|
||||
self._result.append(self._tmp_model)
|
||||
self._state = self.State.DATA
|
||||
self._tmp_model = None
|
||||
self._tmp_order = None
|
||||
elif not line:
|
||||
pass # skip empty line
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
def _entry(self, line):
|
||||
match = self.re_entry.match(line)
|
||||
if match:
|
||||
p = self._float_or_int(match.group(1))
|
||||
ngram = tuple(match.group(4).split(" "))
|
||||
bo_match = match.group(7)
|
||||
bo = self._float_or_int(bo_match) if bo_match else None
|
||||
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
|
||||
elif not line:
|
||||
self._state = self.State.HEADER # last entry
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
@staticmethod
|
||||
def _float_or_int(s):
|
||||
f = float(s)
|
||||
i = int(f)
|
||||
if str(i) == s: # don't drop trailing ".0"
|
||||
return i
|
||||
else:
|
||||
return f
|
||||
|
||||
def load(self, fp):
|
||||
"""Deserialize fp (a file-like object) to a Python object."""
|
||||
return self._parse(fp)
|
||||
|
||||
def loadf(self, path, encoding=None):
|
||||
"""Deserialize path (.arpa, .gz) to a Python object."""
|
||||
path = str(path)
|
||||
if path.endswith(".gz"):
|
||||
with gzip.open(path, mode="rt", encoding=encoding) as f:
|
||||
return self.load(f)
|
||||
else:
|
||||
with open(path, mode="rt", encoding=encoding) as f:
|
||||
return self.load(f)
|
||||
|
||||
def loads(self, s):
|
||||
"""Deserialize s (a str) to a Python object."""
|
||||
with StringIO(s) as f:
|
||||
return self.load(f)
|
||||
|
||||
def dump(self, obj, fp):
|
||||
"""Serialize obj to fp (a file-like object) in ARPA format."""
|
||||
obj.write(fp)
|
||||
|
||||
def dumpf(self, obj, path, encoding=None):
|
||||
"""Serialize obj to path in ARPA format (.arpa, .gz)."""
|
||||
path = str(path)
|
||||
if path.endswith(".gz"):
|
||||
with gzip.open(path, mode="wt", encoding=encoding) as f:
|
||||
return self.dump(obj, f)
|
||||
else:
|
||||
with open(path, mode="wt", encoding=encoding) as f:
|
||||
self.dump(obj, f)
|
||||
|
||||
def dumps(self, obj):
|
||||
"""Serialize obj to an ARPA formatted str."""
|
||||
with StringIO() as f:
|
||||
self.dump(obj, f)
|
||||
return f.getvalue()
|
||||
|
||||
|
||||
def add_log_p(prev_log_sum, log_p, base):
|
||||
return math.log(base**log_p + base**prev_log_sum, base)
|
||||
|
||||
|
||||
def compute_numerator_denominator(lm, h):
|
||||
log_sum_seen_h = -math.inf
|
||||
log_sum_seen_h_lower = -math.inf
|
||||
base = lm.base
|
||||
for w, log_p in lm._ngrams[len(h)][h].items():
|
||||
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
|
||||
|
||||
ngram = h + (w,)
|
||||
log_p_lower = lm.log_p_raw(ngram[1:])
|
||||
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
|
||||
|
||||
numerator = 1.0 - base**log_sum_seen_h
|
||||
denominator = 1.0 - base**log_sum_seen_h_lower
|
||||
return numerator, denominator
|
||||
|
||||
|
||||
def prune(lm, threshold, minorder):
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
|
||||
|
||||
for i in range(
|
||||
lm.order(), max(minorder - 1, 1), -1
|
||||
): # i is the order of the ngram (h, w)
|
||||
logging.info("processing %d-grams ..." % i)
|
||||
count_pruned_ngrams = 0
|
||||
|
||||
h_dict = lm._ngrams[i - 1]
|
||||
for h in list(h_dict.keys()):
|
||||
# old backoff weight, BOW(h)
|
||||
log_bow = lm._log_bo(h)
|
||||
if log_bow is None:
|
||||
log_bow = 0
|
||||
|
||||
# Compute numerator and denominator of the backoff weight,
|
||||
# so that we can quickly compute the BOW adjustment due to
|
||||
# leaving out one prob.
|
||||
numerator, denominator = compute_numerator_denominator(lm, h)
|
||||
|
||||
# assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
|
||||
|
||||
# Compute the marginal probability of the context, P(h)
|
||||
h_log_p = lm.log_joint_prob(h)
|
||||
|
||||
all_pruned = True
|
||||
pruned_w_set = set()
|
||||
|
||||
for w, log_p in h_dict[h].items():
|
||||
ngram = h + (w,)
|
||||
|
||||
# lower-order estimate for ngramProb, P(w|h')
|
||||
backoff_prob = lm.log_p_raw(ngram[1:])
|
||||
|
||||
# Compute BOW after removing ngram, BOW'(h)
|
||||
new_log_bow = math.log(
|
||||
numerator + lm.base**log_p, lm.base
|
||||
) - math.log(denominator + lm.base**backoff_prob, lm.base)
|
||||
|
||||
# Compute change in entropy due to removal of ngram
|
||||
delta_prob = backoff_prob + new_log_bow - log_p
|
||||
delta_entropy = -(lm.base**h_log_p) * (
|
||||
(lm.base**log_p) * delta_prob
|
||||
+ numerator * (new_log_bow - log_bow)
|
||||
)
|
||||
|
||||
# compute relative change in model (training set) perplexity
|
||||
perp_change = lm.base**delta_entropy - 1.0
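# Written out, the quantity above is Stolcke's pruning criterion (sketch):
#   delta_H = -P(h) * [ p(w|h) * (log p'(w|h) - log p(w|h))
#                       + (1 - sum of seen p(v|h)) * (log BOW'(h) - log BOW(h)) ]
# where p'(w|h) = BOW'(h) * p(w|h') is the probability after removal, and the
# n-gram is pruned when base**delta_H - 1 stays below the given threshold.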
|
||||
|
||||
pruned = threshold > 0 and perp_change < threshold
|
||||
|
||||
# Make sure we don't prune ngrams whose backoff nodes are needed
|
||||
if (
|
||||
pruned
|
||||
and len(ngram) in lm._ngrams
|
||||
and len(lm._ngrams[len(ngram)][ngram]) > 0
|
||||
):
|
||||
pruned = False
|
||||
|
||||
logging.debug(
|
||||
"CONTEXT "
|
||||
+ str(h)
|
||||
+ " WORD "
|
||||
+ w
|
||||
+ " CONTEXTPROB %f " % h_log_p
|
||||
+ " OLDPROB %f " % log_p
|
||||
+ " NEWPROB %f " % (backoff_prob + new_log_bow)
|
||||
+ " DELTA-H %f " % delta_entropy
|
||||
+ " DELTA-LOGP %f " % delta_prob
|
||||
+ " PPL-CHANGE %f " % perp_change
|
||||
+ " PRUNED "
|
||||
+ str(pruned)
|
||||
)
|
||||
|
||||
if pruned:
|
||||
pruned_w_set.add(w)
|
||||
count_pruned_ngrams += 1
|
||||
else:
|
||||
all_pruned = False
|
||||
|
||||
# If we removed all ngrams for this context we can
|
||||
# remove the context itself, but only if the present
|
||||
# context is not a prefix to a longer one.
|
||||
if all_pruned and len(pruned_w_set) == len(h_dict[h]):
|
||||
del h_dict[
|
||||
h
|
||||
] # this context h is no longer needed, as its ngram prob is stored at its own context h'
|
||||
elif len(pruned_w_set) > 0:
|
||||
# The pruning for this context h is actually done here
|
||||
old_context = lm.set_new_context(h)
|
||||
|
||||
for w, p_w in old_context.items():
|
||||
if w not in pruned_w_set:
|
||||
lm.add_entry(
|
||||
h + (w,), p_w
|
||||
) # the entry hw is stored at the context h
|
||||
|
||||
# We need to recompute the back-off weight, but
|
||||
# this can only be done after completing the pruning
|
||||
# of the lower-order ngrams.
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
|
||||
|
||||
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
|
||||
|
||||
# recompute backoff weights
|
||||
for i in range(
|
||||
max(minorder - 1, 1) + 1, lm.order() + 1
|
||||
): # be careful of this order: from low- to high-order
|
||||
for h in lm._ngrams[i - 1]:
|
||||
numerator, denominator = compute_numerator_denominator(lm, h)
|
||||
new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
|
||||
lm._ngrams[len(h)][h].log_bo = new_log_bow
|
||||
|
||||
# update counts
|
||||
lm.update_counts()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def check_h_is_valid(lm, h):
|
||||
sum_under_h = sum(
|
||||
[lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
|
||||
)
|
||||
if abs(sum_under_h - 1.0) > 1e-6:
|
||||
logging.info("warning: %s %f" % (str(h), sum_under_h))
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def validate_lm(lm):
|
||||
# sanity check if the conditional probability sums to one under each context h
|
||||
for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
|
||||
logging.info("validating %d-grams ..." % i)
|
||||
h_dict = lm._ngrams[i - 1]
|
||||
for h in h_dict.keys():
|
||||
check_h_is_valid(lm, h)
|
||||
|
||||
|
||||
def compare_two_apras(path1, path2):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load an arpa file
|
||||
logging.info("Loading the arpa file from %s" % args.lm)
|
||||
parser = ArpaParser()
|
||||
models = parser.loadf(args.lm, encoding=default_encoding)
|
||||
lm = models[0] # ARPA files may contain several models.
|
||||
logging.info("Stats before pruning:")
|
||||
for i, cnt in lm.counts():
|
||||
logging.info("ngram %d=%d" % (i, cnt))
|
||||
|
||||
# prune it, the language model will be modified in-place
|
||||
logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
|
||||
prune(lm, args.threshold, args.minorder)
|
||||
|
||||
# validate_lm(lm)
|
||||
|
||||
# write the arpa language model to a file
|
||||
logging.info("Stats after pruning:")
|
||||
for i, cnt in lm.counts():
|
||||
logging.info("ngram %d=%d" % (i, cnt))
|
||||
logging.info("Saving the pruned arpa file to %s" % args.write_lm)
|
||||
parser.dumpf(lm, args.write_lm, encoding=default_encoding)
|
||||
logging.info("Done.")
|
icefall/shared/ngram_entropy_pruning.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/ngram_entropy_pruning.py
|
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
||||
# Arnab Ghoshal, Karel Vesely
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABILITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parse command-line options.
|
||||
# To be sourced by another script (as in ". parse_options.sh").
|
||||
# Option format is: --option-name arg
|
||||
# and shell variable "option_name" gets set to value "arg."
|
||||
# The exception is --help, which takes no arguments, but prints the
|
||||
# $help_message variable (if defined).
|
||||
|
||||
|
||||
###
|
||||
### The --config file options have lower priority to command line
|
||||
### options, so we need to import them first...
|
||||
###
|
||||
|
||||
# Now import all the configs specified by command-line, in left-to-right order
|
||||
for ((argpos=1; argpos<$#; argpos++)); do
|
||||
if [ "${!argpos}" == "--config" ]; then
|
||||
argpos_plus1=$((argpos+1))
|
||||
config=${!argpos_plus1}
|
||||
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
||||
. $config # source the config file.
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
###
|
||||
### Now we process the command line options
|
||||
###
|
||||
while true; do
|
||||
[ -z "${1:-}" ] && break; # break if there are no arguments
|
||||
case "$1" in
|
||||
# If the enclosing script is called with --help option, print the help
|
||||
# message and exit. Scripts should put help messages in $help_message
|
||||
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
||||
else printf "$help_message\n" 1>&2 ; fi;
|
||||
exit 0 ;;
|
||||
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
||||
exit 1 ;;
|
||||
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
||||
# then work out the variable name as $name, which will equal "foo_bar".
|
||||
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
||||
# Next we test whether the variable in question is undefined -- if so it's
|
||||
# an invalid option and we die. Note: $0 evaluates to the name of the
|
||||
# enclosing script.
|
||||
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
||||
# is undefined. We then have to wrap this test inside "eval" because
|
||||
# foo_bar is itself inside a variable ($name).
|
||||
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
||||
|
||||
oldval="`eval echo \\$$name`";
|
||||
# Work out whether we seem to be expecting a Boolean argument.
|
||||
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
||||
was_bool=true;
|
||||
else
|
||||
was_bool=false;
|
||||
fi
|
||||
|
||||
# Set the variable to the right value-- the escaped quotes make it work if
|
||||
# the option had spaces, like --cmd "queue.pl -sync y"
|
||||
eval $name=\"$2\";
|
||||
|
||||
# Check that Boolean-valued arguments are really Boolean.
|
||||
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
||||
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
shift 2;
|
||||
;;
|
||||
*) break;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
# Check for an empty argument to the --cmd option, which can easily occur as a
|
||||
# result of scripting errors.
|
||||
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
||||
|
||||
|
||||
true; # so this script returns exit code 0.
|
icefall/shared/parse_options.sh
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/parse_options.sh
|