This commit is contained in:
JinZr 2023-07-21 00:50:01 +08:00
parent ffb0e7891d
commit 748db76648
29 changed files with 14523 additions and 547 deletions

View File

@ -120,7 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
cd ../..
@ -236,12 +238,11 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET_raw.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
cd ../..
else
log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
@ -324,4 +325,4 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
done
./local/train_bpe_model.py --lang-dir ./data/lang_bpe_${vocab_size}
fi
fi

View File

@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -20,18 +21,11 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
load_manifest,
load_manifest_lazy,
set_caching_enabled,
)
from lhotse.dataset import (
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
@ -40,7 +34,10 @@ from lhotse.dataset import (
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -55,12 +52,13 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id)
class WenetSpeechAsrDataModule:
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
@ -68,6 +66,7 @@ class WenetSpeechAsrDataModule:
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
@ -83,6 +82,20 @@ class WenetSpeechAsrDataModule:
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="""Used only when --mini-libri is False.When enabled,
use 960h LibriSpeech. Otherwise, use 100h subset.""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
@ -147,6 +160,12 @@ class WenetSpeechAsrDataModule:
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
@ -190,10 +209,10 @@ class WenetSpeechAsrDataModule:
)
group.add_argument(
"--training-subset",
"--input-strategy",
type=str,
default="L",
help="The training subset for using",
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
@ -208,12 +227,11 @@ class WenetSpeechAsrDataModule:
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
@ -262,6 +280,7 @@ class WenetSpeechAsrDataModule:
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
@ -292,8 +311,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=300000,
drop_last=True,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
@ -304,6 +322,10 @@ class WenetSpeechAsrDataModule:
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
@ -318,10 +340,6 @@ class WenetSpeechAsrDataModule:
worker_init_fn=worker_init_fn,
)
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_dl.sampler.load_state_dict(sampler_state_dict)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
@ -345,30 +363,28 @@ class WenetSpeechAsrDataModule:
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
batch_size=None,
sampler=valid_sampler,
num_workers=self.args.num_workers,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
@ -376,7 +392,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
@ -386,24 +402,74 @@ class WenetSpeechAsrDataModule:
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def test_net_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_NET cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def test_meeting_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_MEETING cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)

View File

@ -1 +0,0 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao
# Mingshuang Luo)
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -21,53 +20,59 @@
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 35 \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 35 \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search (trivial_graph)
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 35 \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(4) fast beam search (LG)
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 30 \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(5) fast beam search (nbest oracle WER)
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 35 \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
@ -75,6 +80,17 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
@ -86,9 +102,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@ -99,10 +116,8 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_params, get_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -173,10 +188,17 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
@ -186,11 +208,13 @@ def get_parser():
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
@ -212,7 +236,7 @@ def get_parser():
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
@ -222,27 +246,17 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search, fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -251,7 +265,7 @@ def get_parser():
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search, fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -259,9 +273,9 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -275,7 +289,8 @@ def get_parser():
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -283,19 +298,8 @@ def get_parser():
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
@ -306,9 +310,9 @@ def get_parser():
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -326,12 +330,16 @@ def decode_one_batch(
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
@ -358,13 +366,7 @@ def decode_one_batch(
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
hyps = []
@ -377,12 +379,11 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -390,12 +391,25 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
@ -406,85 +420,81 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
hyps.append(sp.decode(hyp).split())
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search_" + key: hyps}
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key += f"_beam_{params.beam}_"
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}_" + key: hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
@ -494,8 +504,12 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
@ -520,15 +534,14 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
@ -536,7 +549,8 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -545,14 +559,16 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
@ -578,7 +594,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -596,7 +613,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -606,10 +623,11 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -635,15 +653,15 @@ def main():
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -657,14 +675,13 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -673,9 +690,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -702,9 +719,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -752,8 +769,9 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
@ -761,47 +779,37 @@ def main():
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
librispeech = LibriSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
dev_cuts = wenetspeech.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dls):
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1,148 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
cut_id: str,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
if params.decoding_method == "fast_beam_search":
assert decoding_graph is not None
assert device == decoding_graph.device
self.params = params
self.cut_id = cut_id
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# how many frames have been processed. (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, at encoder output
self.done_frames: int = 0
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[params.blank_id] * params.context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
decoding_graph
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
@property
def id(self) -> str:
return self.cut_id
def set_features(
self,
features: torch.Tensor,
tail_pad_len: int = 0,
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length + tail_pad_len),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_features, ret_length
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.params.decoding_method == "greedy_search":
return self.hyp[self.params.context_size :] # noqa
elif self.params.decoding_method == "modified_beam_search":
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.params.context_size :] # noqa
else:
assert self.params.decoding_method == "fast_beam_search"
return self.hyp

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1,122 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out

View File

@ -1 +0,0 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1,776 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- joiner-epoch-99-avg-1-chunk-16-left-64.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
self.pad_length = 7 + 2 * 3
def forward(
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
N = x.size(0)
T = self.chunk_size * 2 + self.pad_length
x_lens = torch.tensor([T] * N, device=x.device)
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=x,
x_lens=x_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2)
encoder_states = states[:-2]
logging.info(f"len_encoder_states={len(encoder_states)}")
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, new_states
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens)
return states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
)
decode_chunk_len = encoder_model.chunk_size * 2
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
ds = encoder_model.encoder.downsampling_factor
left_context_len = encoder_model.left_context_len
left_context_len = [left_context_len // k for k in ds]
left_context_len = ",".join(map(str, left_context_len))
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "streaming zipformer2",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 32+7+2*3=45
"num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims,
"cnn_module_kernels": cnn_module_kernels,
"left_context_len": left_context_len,
"query_head_dims": query_head_dims,
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = "embed_states"
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = "processed_lens"
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
logging.info(inputs)
logging.info(outputs)
logging.info(input_names)
logging.info(output_names)
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
suffix += f"-chunk-{params.chunk_size}"
suffix += f"-left-{params.left_context_frames}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/export-onnx.py

View File

@ -0,0 +1,621 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1,545 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is a example for librispeech dataset, if you are using different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer/streaming_decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
import re
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/jit_pretrained.py

View File

@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
./zipformer/jit_pretrained.py \
--nn-model-filename ./zipformer/exp/cpu_jit.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = model.decoder.blank_id
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
features=features,
feature_lengths=feature_lengths,
)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
token_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
for filename, hyp in zip(args.sound_files, hyps):
words = token_ids_to_words(hyp)
s += f"{filename}:\n{words}\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1,273 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
./zipformer/jit_pretrained_streaming.py \
--nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
/path/to/foo.wav \
"""
import argparse
import logging
import math
from typing import List, Optional
import k2
import kaldifeat
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model jit_script.pt",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
device: torch.device = torch.device("cpu"),
):
assert encoder_out.ndim == 2
context_size = decoder.context_size
blank_id = decoder.blank_id
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
# decoder_input.shape (1,, 1 context_size)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
else:
assert decoder_out.ndim == 2
assert hyp is not None, hyp
T = encoder_out.size(0)
for i in range(T):
cur_encoder_out = encoder_out[i : i + 1]
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(
decoder_input, dtype=torch.int32, device=device
).unsqueeze(0)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
return hyp, decoder_out
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
encoder = model.encoder
decoder = model.decoder
joiner = model.joiner
token_table = k2.SymbolTable.from_file(args.tokens)
context_size = decoder.context_size
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor(args.sample_rate)
logging.info(f"Reading sound files: {args.sound_file}")
wave_samples = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=args.sample_rate,
)[0]
logging.info(wave_samples.shape)
logging.info("Decoding started")
chunk_length = encoder.chunk_size * 2
T = chunk_length + encoder.pad_length
logging.info(f"chunk_length: {chunk_length}")
logging.info(f"T: {T}")
states = encoder.get_init_states(device=device)
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
wave_samples = torch.cat([wave_samples, tail_padding])
chunk = int(0.25 * args.sample_rate) # 0.2 second
num_processed_frames = 0
hyp = None
decoder_out = None
start = 0
while start < wave_samples.numel():
logging.info(f"{start}/{wave_samples.numel()}")
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=args.sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= T:
frames = []
for i in range(T):
frames.append(online_fbank.get_frame(num_processed_frames + i))
frames = torch.cat(frames, dim=0).to(device).unsqueeze(0)
x_lens = torch.tensor([T], dtype=torch.int32, device=device)
encoder_out, out_lens, states = encoder(
features=frames,
feature_lengths=x_lens,
states=states,
)
num_processed_frames += chunk_length
hyp, decoder_out = greedy_search(
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
)
text = ""
for i in hyp[context_size:]:
text += token_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
torch.set_num_threads(4)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1,66 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1,358 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss

View File

@ -37,92 +37,193 @@ class MultiDataset:
- aishell4_cuts_train_M.jsonl.gz
- aishell4_cuts_train_S.jsonl.gz
- alimeeting-far_cuts_train.jsonl.gz
- cuts_L.jsonl.gz
- cuts_M.jsonl.gz
- cuts_S.jsonl.gz
- magicdata_cuts_train.jsonl.gz
- primewords_cuts_train.jsonl.gz
- stcmds_cuts_train.jsonl.gz
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
- wenetspeech/cuts_M.jsonl.gz
- wenetspeech/cuts_S.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# LibriSpeech
logging.info("Loading LibriSpeech in lazy mode")
librispeech_cuts = load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
# THCHS-30
logging.info("Loading THCHS-30 in lazy mode")
thchs_30_cuts = load_manifest_lazy(
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
)
# GigaSpeech
filenames = glob.glob(f"{self.fbank_dir}/XL_split/cuts_XL.*.jsonl.gz")
# AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
)
pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
# AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
sorted_filenames = [f[1] for f in idx_filenames]
# AISHELL-4
logging.info("Loading Aishell-4 in lazy mode")
aishell_4_L_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
)
aishell_4_M_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
)
aishell_4_S_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
)
logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
# ST-CMDS
logging.info("Loading ST-CMDS in lazy mode")
stcmds_cuts = load_manifest_lazy(
self.fbank_dir / "stcmds_cuts_train.jsonl.gz"
)
gigaspeech_cuts = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
# Primewords
logging.info("Loading Primewords in lazy mode")
primewords_cuts = load_manifest_lazy(
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
)
# CommonVoice
logging.info("Loading CommonVoice in lazy mode")
commonvoice_cuts = load_manifest_lazy(
self.fbank_dir / f"cv-en_cuts_train.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData in lazy mode")
magicdata_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
)
# LibriHeavy
logging.info("Loading LibriHeavy in lazy mode")
libriheavy_small_cuts = load_manifest_lazy(
self.fbank_dir / "libriheavy_cuts_train_small.jsonl.gz"
)
libriheavy_medium_cuts = load_manifest_lazy(
self.fbank_dir / "libriheavy_cuts_train_medium.jsonl.gz"
)
libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts)
# Aidatatang_200zh
logging.info("Loading Aidatatang_200zh in lazy mode")
aidatatang_200zh_cuts = load_manifest_lazy(
self.fbank_dir / "aidatatang_cuts_train.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting in lazy mode")
alimeeting_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
)
wenetspeech_M_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_M.jsonl.gz"
)
wenetspeech_S_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_S.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech in lazy mode")
kespeech_1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
)
kespeech_2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
)
return CutSet.mux(
librispeech_cuts,
gigaspeech_cuts,
commonvoice_cuts,
libriheavy_cuts,
weights=[
len(librispeech_cuts),
len(gigaspeech_cuts),
len(commonvoice_cuts),
len(libriheavy_cuts),
],
)
thchs_30_cuts,
aishell_cuts,
aishell_2_cuts,
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
aidatatang_200zh_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
wenetspeech_M_cuts,
wenetspeech_S_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
weights=[
len(thchs_30_cuts),
len(aishell_cuts),
len(aishell_2_cuts),
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(aidatatang_200zh_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(wenetspeech_M_cuts),
len(wenetspeech_S_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
],
)
def test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# GigaSpeech
logging.info("Loading GigaSpeech DEV in lazy mode")
gigaspeech_dev_cuts = load_manifest_lazy(self.fbank_dir / "cuts_DEV.jsonl.gz")
# Aidatatang_200zh
logging.info("Loading Aidatatang_200zh DEV set in lazy mode")
aidatatang_dev_cuts = load_manifest_lazy(self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz")
logging.info("Loading GigaSpeech TEST in lazy mode")
gigaspeech_test_cuts = load_manifest_lazy(self.fbank_dir / "cuts_TEST.jsonl.gz")
# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
# CommonVoice
logging.info("Loading CommonVoice DEV in lazy mode")
commonvoice_dev_cuts = load_manifest_lazy(
self.fbank_dir / "cv-en_cuts_dev.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
)
logging.info("Loading CommonVoice TEST in lazy mode")
commonvoice_test_cuts = load_manifest_lazy(
self.fbank_dir / "cv-en_cuts_test.jsonl.gz"
)
return [
gigaspeech_dev_cuts,
gigaspeech_test_cuts,
commonvoice_dev_cuts,
commonvoice_test_cuts,
]
aidatatang_dev_cuts,
aishell_dev_cuts,
aishell2_dev_cuts,
alimeeting_dev_cuts,
magicdata_dev_cuts,
kespeech_dev_phase1_cuts,
kespeech_dev_phase2_cuts,
wenetspeech_dev_cuts,
]

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/onnx_check.py

View File

@ -0,0 +1,241 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model via torchscript (torch.jit.script())
./zipformer/export.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/ \
--jit 1
It will generate the following file in $repo/exp:
- jit_script.pt
3. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
4. Run this file
./zipformer/onnx_check.py \
--jit-filename $repo/exp/jit_script.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
"""
import argparse
import logging
import torch
from onnx_pretrained import OnnxModel
from icefall import is_module_available
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the onnx encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the onnx decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the onnx joiner model",
)
return parser
def test_encoder(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
C = 80
for i in range(3):
N = torch.randint(low=1, high=20, size=(1,)).item()
T = torch.randint(low=30, high=50, size=(1,)).item()
logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
x = torch.rand(N, T, C)
x_lens = torch.randint(low=30, high=T + 1, size=(N,))
x_lens[0] = T
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)
onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
(torch_encoder_out - onnx_encoder_out).abs().max()
)
def test_decoder(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
context_size = onnx_model.context_size
vocab_size = onnx_model.vocab_size
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_decoder: iter {i}, N={N}")
x = torch.randint(
low=1,
high=vocab_size,
size=(N, context_size),
dtype=torch.int64,
)
torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
torch_decoder_out = torch_decoder_out.squeeze(1)
onnx_decoder_out = onnx_model.run_decoder(x)
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
(torch_decoder_out - onnx_decoder_out).abs().max()
)
def test_joiner(
torch_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_joiner: iter {i}, N={N}")
encoder_out = torch.rand(N, encoder_dim)
decoder_out = torch.rand(N, decoder_dim)
projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)
torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
onnx_joiner_out = onnx_model.run_joiner(
projected_encoder_out, projected_decoder_out
)
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
(torch_joiner_out - onnx_joiner_out).abs().max()
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit_filename)
onnx_model = OnnxModel(
encoder_model_filename=args.onnx_encoder_filename,
decoder_model_filename=args.onnx_decoder_filename,
joiner_model_filename=args.onnx_joiner_filename,
)
logging.info("Test encoder")
test_encoder(torch_model, onnx_model)
logging.info("Test decoder")
test_decoder(torch_model, onnx_model)
logging.info("Test joiner")
test_joiner(torch_model, onnx_model)
logging.info("Finished checking ONNX models")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -2,8 +2,7 @@
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang,
# Wei Kang)
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -22,47 +21,50 @@
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/wenetspeech/ASR
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_char/tokens.txt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-9999.pt
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_char/tokens.txt \
--epoch 9999 \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/
--exp-dir $repo/exp \
--causal False
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
2. Run this file
./zipformer/onnx_decode.py \
--exp-dir ./zipformer/exp \
--exp-dir $repo/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
"""
@ -72,14 +74,14 @@ import time
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from onnx_pretrained import OnnxModel, greedy_search
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
from k2 import SymbolTable
def get_parser():
@ -111,15 +113,14 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -133,7 +134,7 @@ def get_parser():
def decode_one_batch(
model: OnnxModel, token_table: k2.SymbolTable, batch: dict
model: OnnxModel, token_table: SymbolTable, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
@ -142,7 +143,7 @@ def decode_one_batch(
model:
The neural model.
token_table:
Mapping ids to tokens.
The token table.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -164,14 +165,20 @@ def decode_one_batch(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [[token_table[h] for h in hyp] for hyp in hyps]
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
hyps = [token_ids_to_words(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
token_table: k2.SymbolTable,
token_table: SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
@ -181,7 +188,7 @@ def decode_dataset(
model:
The neural model.
token_table:
Mapping ids to tokens.
The token table.
Returns:
- A list of tuples. Each tuple contains three elements:
@ -211,7 +218,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = list(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
@ -256,7 +263,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
@ -270,8 +277,7 @@ def main():
device = torch.device("cpu")
logging.info(f"Device: {device}")
token_table = k2.SymbolTable.from_file(args.tokens)
assert token_table[0] == "<blk>"
token_table = SymbolTable.from_file(args.tokens)
logging.info(vars(args))
@ -284,37 +290,20 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
wenetspeech = WenetSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
dev_cuts = wenetspeech.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py

View File

@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./zipformer/onnx_pretrained-streaming.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
logging.info(f"encoder_meta={encoder_meta}")
model_type = encoder_meta["model_type"]
assert model_type == "zipformer2", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"]
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
left_context_len = encoder_meta["left_context_len"]
query_head_dims = encoder_meta["query_head_dims"]
value_head_dims = encoder_meta["value_head_dims"]
num_heads = encoder_meta["num_heads"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
query_head_dims = to_int_list(query_head_dims)
value_head_dims = to_int_list(value_head_dims)
num_heads = to_int_list(num_heads)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
logging.info(f"query_head_dims: {query_head_dims}")
logging.info(f"value_head_dims: {value_head_dims}")
logging.info(f"num_heads: {num_heads}")
num_encoders = len(num_encoder_layers)
self.states = []
for i in range(num_encoders):
num_layers = num_encoder_layers[i]
key_dim = query_head_dims[i] * num_heads[i]
embed_dim = encoder_dims[i]
nonlin_attn_head_dim = 3 * embed_dim // 4
value_dim = value_head_dims[i] * num_heads[i]
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
encoder_input[name] = tensors[0]
encoder_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
encoder_input[name] = tensors[1]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
encoder_input[name] = tensors[2]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
encoder_input[name] = tensors[3]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
encoder_input[name] = tensors[4]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
encoder_input[name] = tensors[5]
encoder_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = "embed_states"
embed_states = self.states[-2]
encoder_input[name] = embed_states
encoder_output.append(f"new_{name}")
# (batch_size,)
name = "processed_lens"
processed_lens = self.states[-1]
encoder_input[name] = processed_lens
encoder_output.append(f"new_{name}")
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> List[int]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
token_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += token_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

View File

@ -0,0 +1,419 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--causal False
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file
./zipformer/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
"""
import argparse
import logging
import math
from typing import List, Tuple
import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def run_encoder(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 2-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, its shape is (N, T', joiner_dim)
- encoder_out_lens, its shape is (N,)
"""
out = self.encoder.run(
[
self.encoder.get_outputs()[0].name,
self.encoder.get_outputs()[1].name,
],
{
self.encoder.get_inputs()[0].name: x.numpy(),
self.encoder.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, joiner_dim)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.run_decoder(decoder_input)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
# current_encoder_out's shape: (batch_size, joiner_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.run_joiner(current_encoder_out, decoder_out)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = model.run_decoder(decoder_input)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
token_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
for filename, hyp in zip(args.sound_files, hyps):
words = token_ids_to_words(hyp)
s += f"{filename}:\n{words}\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/optim.py

File diff suppressed because it is too large Load Diff

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1,381 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
Note: This is a example for librispeech dataset, if you are using different
dataset, you should change the argument values according to your dataset.
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./zipformer/exp/epoch-xx.pt`.
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
logging.info("Creating model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
hyps = []
msg = f"Using {params.method}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1,104 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import (
Balancer,
Dropout3,
ScaleGrad,
SwooshL,
SwooshLOnnx,
SwooshR,
SwooshROnnx,
Whiten,
)
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, SwooshR):
d[name] = SwooshROnnx()
elif is_onnx and isinstance(m, SwooshL):
d[name] = SwooshLOnnx()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1,295 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
blank_penalty: float = 0.0,
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
if blank_penalty != 0.0:
logits[:, 0] -= blank_penalty
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
blank_penalty: float = 0.0,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
if blank_penalty != 0.0:
logits[:, 0] -= blank_penalty
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
blank_penalty: float = 0.0,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
if blank_penalty != 0.0:
logits[:, 0] -= blank_penalty
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]

View File

@ -23,7 +23,7 @@ Usage:
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 16 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
@ -38,8 +38,9 @@ from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import WenetSpeechAsrDataModule
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@ -50,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_params, get_model
from icefall.checkpoint import (
average_checkpoints,
@ -58,7 +59,6 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
@ -123,10 +123,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--bpe-model",
type=str,
default="data/lang_char",
help="Path to the lang dir(containing lexicon, tokens, etc.)",
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
@ -181,18 +181,6 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--num-decode-streams",
type=int,
@ -294,7 +282,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens)
return batch_states
@ -332,7 +322,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
@ -363,7 +355,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
@ -386,7 +380,11 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
@ -406,7 +404,9 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
@ -495,10 +495,7 @@ def decode_one_chunk(
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
encoder_out=encoder_out,
streams=decode_streams,
blank_penalty=params.blank_penalty,
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
@ -511,7 +508,6 @@ def decode_one_chunk(
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
@ -519,10 +515,11 @@ def decode_one_chunk(
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states)
@ -540,7 +537,7 @@ def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -552,8 +549,8 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
lexicon:
The Lexicon.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
@ -580,7 +577,9 @@ def decode_dataset(
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -596,12 +595,7 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
if audio.max() > 1:
logging.warning(
f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}."
f"Clipping to [-1, 1]."
)
audio = np.clip(audio, -1, 1)
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
@ -620,11 +614,8 @@ def decode_dataset(
decode_results.append(
(
decode_streams[i].id,
list(decode_streams[i].ground_truth.strip()),
[
lexicon.token_table[idx]
for idx in decode_streams[i].decoding_result()
],
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@ -642,27 +633,25 @@ def decode_dataset(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
[
lexicon.token_table[idx]
for idx in decode_streams[i].decoding_result()
],
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
key = f"greedy_search_{key}"
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_{key}"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}_{key}"
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
@ -695,7 +684,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -713,7 +703,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -728,13 +718,14 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
@ -754,9 +745,13 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -765,9 +760,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -794,9 +789,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -851,23 +846,23 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
wenetspeech = WenetSpeechAsrDataModule(args)
librispeech = LibriSpeechAsrDataModule(args)
dev_cuts = wenetspeech.valid_cuts()
test_net_cuts = wenetspeech.test_net_cuts()
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts]
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
lexicon=lexicon,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/subsampling.py

View File

@ -0,0 +1,407 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import warnings
import torch
from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels
Output dim. The output shape is (N, (T-3)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
bottleneck:
bottleneck dimension for 1d squeeze-excite
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item()
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item()
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(
batch_size, channels, left_pad, freq
).to(device)
return cached_embed_left_pad

View File

@ -21,29 +21,33 @@
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--world-size 8 \
--num-epochs 12 \
--start-epoch 1 \
--exp-dir zipformer/exp \
--training-subset L
--lr-epochs 1.5 \
--max-duration 350
# For mix precision training:
./zipformer/train.py \
--world-size 8 \
--num-epochs 12 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--training-subset L \
--lr-epochs 1.5 \
--max-duration 750
--full-libri 1 \
--max-duration 1000
# For streaming model training:
./zipformer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--full-libri 1 \
--max-duration 1000
It supports training with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
@ -57,10 +61,11 @@ from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -77,7 +82,6 @@ from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -87,7 +91,6 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
@ -139,42 +142,42 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--feedforward-dim",
type=str,
default="512,768,1024,1536,1024,768",
help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
)
parser.add_argument(
"--num-heads",
type=str,
default="4,4,4,8,4,4",
help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
default="192,256,384,512,384,256",
help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""",
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--query-head-dim",
type=str,
default="32",
help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--value-head-dim",
type=str,
default="12",
help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""",
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--pos-head-dim",
type=str,
default="4",
help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
@ -188,14 +191,16 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-unmasked-dim",
type=str,
default="192,192,256,256,256,192",
help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
)
parser.add_argument(
"--cnn-module-kernel",
type=str,
default="31,31,15,15,15,31",
help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.",
)
parser.add_argument(
@ -226,16 +231,31 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--chunk-size",
type=str,
default="16,32,64,-1",
help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""",
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False",
)
parser.add_argument(
"--left-context-frames",
type=str,
default="64,128,256,-1",
help="""Maximum left-contexts for causal training, measured in frames which will
be converted to a number of chunks. If splitting into chunks,
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
)
parser.add_argument(
"--use-transducer",
type=str2bool,
default=True,
help="If True, use Transducer head.",
)
parser.add_argument(
"--use-ctc",
type=str2bool,
default=False,
help="If True, use CTC head.",
)
@ -302,13 +322,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--bpe-model",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
@ -335,47 +352,55 @@ def get_parser():
"--ref-duration",
type=float,
default=600,
help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""",
help="Reference batch duration for purposes of adjusting batch counts for setting various "
"schedules inside the model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="""The prune range for rnnt loss, it means how many symbols(context)
we are using to compute the loss""",
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="""The scale to smooth the loss with lm
(output of prediction network) part.""",
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="""The scale to smooth the loss with am (output of encoder network) part.""",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="""To get pruning ranges, we will calculate a simple version
loss(joiner is just addition), this simple loss also uses for
training (as a regularization item). We will scale the simple loss
with this parameter before adding to the final loss.""",
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
@ -408,7 +433,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
@ -502,7 +527,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for zipformer
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
@ -579,19 +604,32 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
assert (
params.use_transducer or params.use_ctc
), (f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}")
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
if params.use_transducer:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
else:
decoder = None
joiner = None
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
)
return model
@ -659,9 +697,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -718,12 +753,12 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute loss given the model and its inputs.
Args:
params:
@ -753,11 +788,11 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device)
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
simple_loss, pruned_loss, ctc_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -766,21 +801,27 @@ def compute_loss(
lm_scale=params.lm_scale,
)
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = 0.0
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
@ -791,8 +832,11 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
@ -800,7 +844,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@ -813,7 +857,7 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
sp=sp,
batch=batch,
is_training=False,
)
@ -836,7 +880,7 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@ -880,8 +924,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
saved_bad_model = False
def save_bad_model(suffix: str = ""):
@ -900,9 +942,6 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -912,7 +951,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
sp=sp,
batch=batch,
is_training=True,
)
@ -929,7 +968,7 @@ def train_one_epoch(
optimizer.zero_grad()
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5:
@ -950,7 +989,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -963,7 +1001,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -1020,7 +1057,7 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
@ -1073,14 +1110,15 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
logging.info(params)
@ -1135,23 +1173,25 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
wenetspeech = WenetSpeechAsrDataModule(args)
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = wenetspeech.train_cuts()
valid_cuts = wenetspeech.valid_cuts()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 15 seconds
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 15.0 here. Please see
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 15.0:
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
@ -1162,7 +1202,7 @@ def run(rank, world_size, args):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
@ -1186,18 +1226,20 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = wenetspeech.train_dataloaders(
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if False and not params.print_diagnostics:
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
sp=sp,
params=params,
)
@ -1222,7 +1264,7 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
graph_compiler=graph_compiler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@ -1256,7 +1298,7 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
@ -1266,8 +1308,8 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
graph_compiler:
The compiler to encode texts to ids.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
@ -1280,8 +1322,7 @@ def display_and_save_batch(
logging.info(f"features shape: {features.shape}")
texts = supervisions["text"]
y = graph_compiler.texts_to_ids(texts)
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@ -1290,7 +1331,7 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@ -1306,7 +1347,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
sp=sp,
batch=batch,
is_training=True,
)
@ -1321,7 +1362,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@ -1330,9 +1371,8 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/zipformer.py

File diff suppressed because it is too large Load Diff