Merge branch 'master' of github.com:marcoyang1998/icefall into mls_recipe

This commit is contained in:
marcoyang 2024-02-28 13:37:14 +08:00
commit 25f465c8c0
81 changed files with 13281 additions and 145 deletions


@ -51,6 +51,7 @@ RUN pip install --no-cache-dir \
onnxruntime \
pytest \
sentencepiece>=0.1.96 \
pypinyin==0.50.0 \
six \
tensorboard \
typeguard


@ -43,12 +43,12 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20240218"
kaldifeat_version = "1.25.4.dev20240218"
version = "1.3"
k2_version = "1.24.4.dev20240223"
kaldifeat_version = "1.25.4.dev20240223"
version = "20240223"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
torch_version += ["2.2.0"]
torch_version += ["2.2.0", "2.2.1"]
matrix = []
for p in python_version:


@ -0,0 +1,140 @@
Finetune from a supervised pre-trained Zipformer model
======================================================
This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
transducer model on a new dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have set up
the environment for ``icefall``.
.. HINT::
We recommend using one or more GPUs to run this recipe.
For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
own data for fine-tuning if you create a manifest for your new dataset (a minimal
sketch is given at the end of the next section).
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. We only require the small subset of GigaSpeech for this tutorial.
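If you plan to fine-tune on your own data instead, you only need a Lhotse ``CutSet``
manifest for it. Below is a minimal, illustrative sketch (not part of the recipe) of building
such a manifest from a list of audio files and transcripts; all paths and fields here are
hypothetical and should be adapted to your data:
.. code-block:: python
from lhotse import (
    CutSet,
    Recording,
    RecordingSet,
    SupervisionSegment,
    SupervisionSet,
)

# Hypothetical inputs: (audio_path, transcript) pairs of your own data.
data = [("audio/utt1.wav", "HELLO WORLD"), ("audio/utt2.wav", "GOOD MORNING")]

recordings, supervisions = [], []
for idx, (path, text) in enumerate(data):
    rec = Recording.from_file(path, recording_id=f"utt{idx}")
    recordings.append(rec)
    supervisions.append(
        SupervisionSegment(
            id=f"utt{idx}-sup",
            recording_id=rec.id,
            start=0.0,
            duration=rec.duration,
            text=text,
            language="English",
        )
    )

cuts = CutSet.from_manifests(
    recordings=RecordingSet.from_recordings(recordings),
    supervisions=SupervisionSet.from_segments(supervisions),
)
cuts.to_file("data/fbank/my_dataset_cuts_raw.jsonl.gz")
You would still need to compute features for these cuts (see the ``local/compute_fbank_*.py``
scripts in the recipe) before they can be used for fine-tuning.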
Model preparation
-----------------
We use the Zipformer model trained on the full LibriSpeech dataset (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following commands:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune
---------
Since LibriSpeech and GigaSpeech are both English datasets, we can initialize the whole
Zipformer model with the checkpoint downloaded in the previous step (otherwise we would need to
initialize the stateless decoder and joiner from scratch due to the mismatch in output
vocabulary). The following command starts a fine-tuning experiment:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/finetune.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-fp16 1 \
--base-lr 0.0045 \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--use-mux $use_mux \
--master-port 13024 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--base-lr``
The learning rate used for fine-tuning. We suggest setting a **small** learning rate for fine-tuning;
otherwise the model may forget the initialization very quickly. A reasonable value is around
1/10 of the original learning rate, i.e., 0.0045.
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume your fine-tuning experiment from a certain epoch, you
need to set this to False.**
- ``--finetune-ckpt``
The path to the pre-trained checkpoint (used for initialization).
- ``--use-mux``
If True, mix the fine-tuning data with the original training data by using `CutSet.mux <https://lhotse.readthedocs.io/en/latest/api.html#lhotse.supervision.SupervisionSet.mux>`_ (see the sketch after this list).
This helps maintain the model's performance on the original domain if the original training
data is available. **If you don't have the original training data, please set it to False.**
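For reference, below is a minimal sketch of what mixing two cut sets with ``CutSet.mux``
looks like. The manifest paths are illustrative; in the recipe, the mixing is performed inside
the fine-tuning script when ``--use-mux`` is set:
.. code-block:: python
from lhotse import CutSet

# Illustrative manifest paths; replace them with your actual manifests.
giga_cuts = CutSet.from_file("data/fbank/gigaspeech_cuts_S.jsonl.gz")
libri_cuts = CutSet.from_file("data/fbank/librispeech_cuts_train.jsonl.gz")

# mux() lazily interleaves the two streams; `weights` controls how often
# each source is sampled.
mixed_cuts = CutSet.mux(giga_cuts, libri_cuts, weights=[0.5, 0.5])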
After fine-tuning, let's test the WERs. You can do this via the following command:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/decode_gigaspeech.py \
--epoch 20 \
--avg 10 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-averaged-model 1 \
--max-duration 1000 \
--decoding-method greedy_search
You should see numbers similar to the ones below:
.. code-block:: text
For dev, WER of different settings are:
greedy_search 13.47 best for dev
For test, WER of different settings are:
greedy_search 13.66 best for test
Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
on the GigaSpeech test sets.


@ -0,0 +1,15 @@
Fine-tune a pre-trained model
=============================
After pre-training on publicly available datasets, the ASR model is already capable of
performing general speech recognition with relatively high accuracy. However, the accuracy
can still be low on domains that are quite different from the original training
set. In this case, we can fine-tune the model with a small amount of additional labelled
data to improve its performance on the new domain.
.. toctree::
:maxdepth: 2
:caption: Table of Contents
from_supervised/finetune_zipformer


@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.
Streaming-ASR/index
RNN-LM/index
TTS/index
Finetune/index


@ -30,15 +30,15 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_gigaspeech_dev_test():
def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
batch_duration = 600
batch_duration = 1000
subsets = ("DEV", "TEST")
subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
@ -80,7 +80,7 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_gigaspeech_dev_test()
compute_fbank_gigaspeech()
if __name__ == "__main__":


@ -51,14 +51,6 @@ def get_parser():
"Determines batch size dynamically.",
)
parser.add_argument(
"--subset",
type=str,
default="XL",
choices=["XL", "L", "M", "S", "XS"],
help="Which subset to work with",
)
parser.add_argument(
"--num-splits",
type=int,
@ -84,7 +76,7 @@ def get_parser():
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
output_dir = f"data/fbank/{args.subset}_split"
output_dir = f"data/fbank/XL_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@ -107,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -121,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,


@ -16,17 +16,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Whether to use speed perturbation.",
)
return parser.parse_args()
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@ -42,7 +56,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None
def preprocess_giga_speech():
def preprocess_giga_speech(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
@ -51,6 +65,10 @@ def preprocess_giga_speech():
"DEV",
"TEST",
"XL",
"L",
"M",
"S",
"XS",
)
logging.info("Loading manifest (may take 4 minutes)")
@ -71,7 +89,7 @@ def preprocess_giga_speech():
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
@ -94,11 +112,14 @@ def preprocess_giga_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
if args.perturb_speed:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
@ -107,7 +128,8 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_giga_speech()
args = get_args()
preprocess_giga_speech(args)
if __name__ == "__main__":


@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
exit 1;
fi
# Download the XL, L, M, S, XS, DEV and TEST subsets by default.
lhotse download gigaspeech --subset auto --host tsinghua \
lhotse download gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
lhotse prepare gigaspeech --subset auto -j $nj \
lhotse prepare gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
fi
@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
python3 ./local/compute_fbank_gigaspeech_dev_test.py
log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
python3 ./local/compute_fbank_gigaspeech.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Prepare phone based lang"
log "Stage 9: Prepare transcript_words.txt and words.txt"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/transcript_words.txt ]; then
gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
| jq '.text' \
@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare BPE based lang"
log "Stage 10: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare bigram P"
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare bigram P"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare G"
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
fi
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Compile HLG"
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do


@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=30,
default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@ -368,6 +368,8 @@ class GigaSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
@ -417,6 +419,7 @@ class GigaSpeechAsrDataModule:
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)


@ -416,6 +416,17 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--scan-for-oom-batches",
type=str2bool,
default=False,
help="""
Whether to scan for OOM batches before training. This is helpful for
finding a suitable max_duration; you only need to run it once.
Caution: it is somewhat time consuming.
""",
)
parser.add_argument(
"--inf-check",
type=str2bool,
@ -1171,9 +1182,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
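# Cuts too short to yield at least one subsampled frame are dropped.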
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
gigaspeech = GigaSpeechAsrDataModule(args)
train_cuts = gigaspeech.train_cuts()
train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@ -1187,9 +1205,10 @@ def run(rank, world_size, args):
)
valid_cuts = gigaspeech.dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,


@ -0,0 +1,49 @@
# Results
## zipformer transducer model
This is a tiny general ASR model with around 3.3M parameters; see PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details.
The modeling units are 500 BPE tokens trained on the GigaSpeech transcripts.
The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is the GigaSpeech test set (about 40 hours of audio).
The whole pipeline (training, decoding and fine-tuning commands) is in `run.sh`.
The models have been uploaded to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz).
Below are the results on a small test set of 20 commands. We list the results for every command; for each metric there are two columns, one for the original model trained on the GigaSpeech XL subset and the other for the model fine-tuned on the commands dataset.
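(The false-alarm columns are the number of false positives divided by the duration of the negative test data; for example, 1 false positive over 40 hours gives 1 / 40 = 0.025 false alarms per hour.)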
Commands | FN in positive set (original) | FN in positive set (finetune) | Recall (original) | Recall (finetune) | FP in negative set (original) | FP in negative set (finetune) | False alarms per hour, 40 h negative set (original) | False alarms per hour, 40 h negative set (finetune)
-- | -- | -- | -- | -- | -- | -- | -- | --
All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6
Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225
Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025
Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05
Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0
Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0
Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1
Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05
Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0
Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0
Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025
Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025
Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0
Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0
Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0
Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0
Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075
Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025
Below is the result on the large test set. It has more than 200 commands, too many to list individually, so we report only the overall result.
Commands | FN in positive set (original) | FN in positive set (finetune) | Recall (original) | Recall (finetune) | FP in negative set (original) | FP in negative set (finetune) | False alarms per hour, 23 h negative set (original) | False alarms per hour, 23 h negative set (finetune)
-- | -- | -- | -- | -- | -- | -- | -- | --
All | 622/3994 | 79/3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3

egs/gigaspeech/KWS/prepare.sh Executable file

@ -0,0 +1,85 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare gigaspeech dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.gigaspeech.done ]; then
pushd ../ASR
./prepare.sh --stage 0 --stop-stage 9
./prepare.sh --stage 11 --stop-stage 11
popd
pushd data/fbank
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
ln -svf $(realpath ../ASR/data/fbank/XL_split) .
ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
ln -svf $(realpath ../ASR/data/lang_bpe_500) .
popd
touch data/fbank/.gigaspeech.done
else
log "Gigaspeech dataset already exists, skipping."
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
pushd open-commands
./script/prepare.sh --stage 2 --stop-stage 2
./script/prepare.sh --stage 6 --stop-stage 6
popd
popd
pushd data/fbank
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
popd
touch data/fbank/.fluent_speech_commands.done
else
log "Fluent speech commands dataset already exists, skipping."
fi
fi

egs/gigaspeech/KWS/run.sh Executable file

@ -0,0 +1,197 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=../../../:$PYTHONPATH
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
if [ ! -e data/fbank/.gigaspeech.done ]; then
log "You need to run the prepare.sh first."
exit -1
fi
python ./zipformer/train.py \
--world-size 4 \
--exp-dir zipformer/exp \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--num-epochs 12 \
--lr-epochs 1.5 \
--use-fp16 1 \
--start-epoch 1 \
--subset XL \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--max-duration 1000
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Decode the model."
for t in small large; do
python ./zipformer/decode.py \
--epoch 12 \
--avg 2 \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Export the model."
python ./zipformer/export.py \
--epoch 12 \
--avg 2 \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 12 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
# The following learning-rate schedule configuration should work well.
# You may also tune these parameters to adjust the schedule.
base_lr=0.0005
lr_epochs=100
lr_batches=100000
# We recommend starting from an averaged model
finetune_ckpt=zipformer/exp/pretrained.pt
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--exp-dir zipformer/exp_finetune \
--bpe-model data/lang_bpe_500/bpe.model \
--use-fp16 1 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1 \
--base-lr $base_lr \
--lr-epochs $lr_epochs \
--lr-batches $lr_batches \
--finetune-ckpt $finetune_ckpt \
--max-duration 1500
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 1: Decode the finetuned model."
for t in small large; do
python ./zipformer/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 2: Export the finetuned model."
python ./zipformer/export.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--tokens data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp_finetune \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 10 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi

egs/gigaspeech/KWS/shared Symbolic link

@ -0,0 +1 @@
../../../icefall/shared


@ -0,0 +1,477 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import inspect
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import lhotse
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class GigaSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
# GigaSpeech specific arguments
group.add_argument(
"--subset",
type=str,
default="XL",
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
)
group.add_argument(
"--small-dev",
type=str2bool,
default=False,
help="Should we use only 1000 utterances for dev (speeds up training)",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get train {self.args.subset} cuts")
if self.args.subset == "XL":
filenames = glob.glob(
f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
)
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
else:
path = (
self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
)
cuts_train = CutSet.from_jsonl_lazy(path)
return cuts_train
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
)
if self.args.small_dev:
return cuts_valid.subset(first=1000)
else:
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)
@lru_cache()
def fsc_train_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
)
@lru_cache()
def fsc_valid_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands valid cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
)
@lru_cache()
def fsc_test_small_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands small test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
)
@lru_cache()
def fsc_test_large_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands large test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
)


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large


@ -0,0 +1,689 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--keywords-file keywords.txt \
--beam-size 4
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import (
keywords_search,
)
from train import add_model_arguments, get_model, get_params
from lhotse.cut import Cut
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@dataclass
class KwMetric:
TP: int = 0 # True positive
FN: int = 0 # False negative
FP: int = 0 # False positive
TN: int = 0 # True negative
FN_list: List[str] = field(default_factory=list)
FP_list: List[str] = field(default_factory=list)
TP_list: List[str] = field(default_factory=list)
def __str__(self) -> str:
return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--beam",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--keywords-file",
type=str,
help="File contains keywords.",
)
parser.add_argument(
"--test-set",
type=str,
default="small",
help="small or large",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
help="""
The default boosting score (token level) for keywords. It will boost the
paths that match keywords so that they survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=1,
help="The number of tailing blanks should have after hitting one keyword.",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
keywords_graph: Optional[ContextGraph] = None,
) -> List[List[Tuple[str, Tuple[int, int]]]]:
"""Decode one batch and return the result in a list.
The length of the list equals the batch size; the i-th element contains the
triggered keywords for the i-th utterance in the given batch. The triggered
keywords are also a list, each element of which is a tuple of the detected
keyword phrase and its start and end timestamps.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
keywords_graph:
The graph containing keywords.
Returns:
Return the decoding result. See above description for the format of
the returned list.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
ans_dict = keywords_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
keywords_graph=keywords_graph,
beam=params.beam,
num_tailing_blanks=params.num_tailing_blanks,
blank_penalty=params.blank_penalty,
)
hyps = []
for ans in ans_dict:
hyp = []
for hit in ans:
hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1])))
hyps.append(hyp)
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
keywords_graph:
The graph containing keywords.
Returns:
Return a tuple of (results, metric). results is a list of tuples
(cut_id, ref_words, hyp_words), where hyp_words are the words of the
triggered keywords; metric is a dict mapping each keyword (and the
special key "all") to its KwMetric.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 50
results = []
metric = {"all": KwMetric()}
for k in keywords:
metric[k] = KwMetric()
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
sp=sp,
keywords_graph=keywords_graph,
batch=batch,
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = ref_text.upper()
ref_words = ref_text.split()
hyp_words = [x[0] for x in hyp_words]
# for computing WER
this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
hyp_set = set(hyp_words) # each item is a keyword phrase
if len(hyp_words) > 1:
logging.warning(
f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
f"please check the transcript to see if it really has more "
f"than one keywords, if so consider splitting this audio and"
f"keep only one keyword for each audio."
)
hyp_str = " | ".join(
hyp_words
) # The triggered keywords for this utterance.
TP = False
FP = False
for x in hyp_set:
assert x in keywords, x # can only trigger keywords
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
TP = True
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if (test_only_keywords and x != ref_text) or (
not test_only_keywords and x not in ref_text
):
FP = True
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x})")
if TP:
metric["all"].TP += 1
if FP:
metric["all"].FP += 1
TN = True  # If all keywords are true negatives, the overall result is a true negative.
FN = False
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric[x].TN += 1
continue
TN = False
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
fn = True
for y in hyp_set:
if (test_only_keywords and y == ref_text) or (
not test_only_keywords and y in ref_text
):
fn = False
break
if fn:
FN = True
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
if TN:
metric["all"].TN += 1
if FN:
metric["all"].FN += 1
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, metric
def save_results(
params: AttributeDict,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
metric: KwMetric,
):
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (
0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
)
recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
s += f"\tAccuracy: {acc:.3f}\n"
s += f"\tPrecision: {precision:.3f}\n"
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
of.write(f"\n\n{params.keywords_config}")
logging.info("Wrote metric stats to {}".format(metric_filename))
@torch.no_grad()
def main():
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "kws"
params.suffix = params.test_set
if params.iter > 0:
params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
else:
params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"-score-{params.keywords_score}"
params.suffix += f"-threshold-{params.keywords_threshold}"
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
phrases = []
token_ids = []
keywords_scores = []
keywords_thresholds = []
keywords_config = []
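# Each line of the keywords file is one keyword phrase. A token starting with ":"
# sets a per-keyword boosting score and a token starting with "#" sets a
# per-keyword triggering threshold, e.g. (values are illustrative):
#   LIGHTS ON :2.0 #0.4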
with open(params.keywords_file, "r") as f:
for line in f.readlines():
keywords_config.append(line)
score = 0
threshold = 0
keyword = []
words = line.strip().upper().split()
for word in words:
word = word.strip()
if word[0] == ":":
score = float(word[1:])
continue
if word[0] == "#":
threshold = float(word[1:])
continue
keyword.append(word)
keyword = " ".join(keyword)
phrases.append(keyword)
token_ids.append(sp.encode(keyword))
keywords_scores.append(score)
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
ac_thresholds=keywords_thresholds,
)
keywords = set(phrases)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
test_cuts = gigaspeech.test_cuts()
test_dl = gigaspeech.test_dataloaders(test_cuts)
if params.test_set == "small":
test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
test_sets = ["small-fsc", "test"]
test_dls = [test_fsc_small_dl, test_dl]
else:
assert params.test_set == "large", params.test_set
test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
test_sets = ["large-fsc", "test"]
test_dls = [test_fsc_large_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="fsc" in test_set,
)
save_results(
params=params,
test_set_name=test_set,
results=results,
metric=metric,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1,644 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# For non-streaming model training:
./zipformer/finetune.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For streaming model training:
./zipformer/finetune.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
It supports training with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
import argparse
import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
)
from train import (
add_model_arguments,
add_training_arguments,
compute_loss,
compute_validation_loss,
display_and_save_batch,
get_adjusted_batch_count,
get_model,
get_params,
load_checkpoint_if_available,
save_checkpoint,
scan_pessimistic_batches_for_oom,
set_batch_count,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to mix the new (target-domain) data with the original training data
when fine-tuning. If true, the new data is muxed with the original data using
the CutSet.mux weights defined in run() (by default 10% new data and 90%
original data).
""",
)
parser.add_argument(
"--init-modules",
type=str,
default=None,
help="""
Modules to be initialised from the checkpoint. It matches all parameters
starting with a specific prefix; multiple prefixes are given as a
comma-separated list. If None, all modules will be initialised. For example,
to initialise only the parameters starting with "encoder", use "encoder";
to initialise the parameters starting with "encoder" or "joiner",
use "encoder,joiner".
""",
)
parser.add_argument(
"--finetune-ckpt",
type=str,
default=None,
help="Fine-tuning from which checkpoint (a path to a .pt file)",
)
parser.add_argument(
"--continue-finetune",
type=str2bool,
default=False,
help="Continue finetuning or finetune from pre-trained model",
)
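# A hypothetical fine-tuning invocation (placeholder paths), starting from a
# pre-trained checkpoint while re-initialising only the encoder:
#   ./zipformer/finetune.py --world-size 2 --num-epochs 20 \
#     --exp-dir zipformer/exp_finetune --max-duration 1000 \
#     --finetune-ckpt /path/to/pretrained.pt --init-modules "encoder"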
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_training_arguments(parser)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
def load_model_params(
ckpt: str, model: nn.Module, init_modules: Optional[List[str]] = None, strict: bool = True
):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
# if module list is empty, load the whole model from ckpt
if not init_modules:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
else:
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
# if params.continue_finetune:
# set_batch_count(model, params.batch_idx_train)
# else:
# set_batch_count(model, params.batch_idx_train + 100000)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
else:
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
gigaspeech = GigaSpeechAsrDataModule(args)
if params.use_mux:
train_cuts = CutSet.mux(
gigaspeech.train_cuts(),
gigaspeech.fsc_train_cuts(),
weights=[0.9, 0.1],
)
else:
train_cuts = gigaspeech.fsc_train_cuts()
train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
train_dl = gigaspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = gigaspeech.fsc_valid_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -0,0 +1 @@
../../ASR/zipformer/gigaspeech_scoring.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
@ -31,6 +32,7 @@ from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
KeywordResult,
add_eos,
add_sos,
get_texts,
@ -789,6 +791,8 @@ class Hypothesis:
# It contains only one entry.
log_prob: torch.Tensor
ac_probs: Optional[List[float]] = None
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
@ -805,6 +809,8 @@ class Hypothesis:
# Context graph state
context_state: Optional[ContextState] = None
num_tailing_blanks: int = 0
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@ -953,6 +959,241 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
return ans
def keywords_search(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
keywords_graph: ContextGraph,
beam: int = 4,
num_tailing_blanks: int = 0,
blank_penalty: float = 0,
) -> List[List[KeywordResult]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
keywords_graph:
An instance of ContextGraph containing the keywords and their configurations.
beam:
Number of active paths during the beam search.
num_tailing_blanks:
The number of trailing blanks a keyword should be followed by before it is
emitted; this handles the scenario where one keyword is a prefix of another.
In most cases, you can just set it to 0.
blank_penalty:
The score used to penalize blank probability.
Returns:
Return a list of lists of KeywordResult, one inner list per utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert keywords_graph is not None
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
sorted_ans = [[] for _ in range(N)]
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
if blank_penalty != 0:
logits[:, 0] -= blank_penalty
probs = logits.softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = probs.log()
probs = probs.reshape(-1)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
hyp_probs = ragged_probs[i].tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
new_ac_probs = hyp.ac_probs[:]
context_score = 0
new_context_state = hyp.context_state
new_num_tailing_blanks = hyp.num_tailing_blanks + 1
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_ac_probs.append(hyp_probs[topk_indexes[k]])
(
context_score,
new_context_state,
_,
) = keywords_graph.forward_one_step(hyp.context_state, new_token)
new_num_tailing_blanks = 0
if new_context_state.token == -1: # root
new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
new_log_prob = topk_log_probs[k] + context_score
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
ac_probs=new_ac_probs,
context_state=new_context_state,
num_tailing_blanks=new_num_tailing_blanks,
)
B[i].add(new_hyp)
top_hyp = B[i].get_most_probable(length_norm=True)
matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
if (
matched
and top_hyp.num_tailing_blanks > num_tailing_blanks
and ac_prob >= matched_state.ac_threshold
):
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
timestamps=top_hyp.timestamp[-matched_state.level :],
phrase=matched_state.phrase,
)
sorted_ans[i].append(keyword)
B[i] = HypothesisList()
B[i].add(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
)
B = B + finalized_B
for i, hyps in enumerate(B):
top_hyp = hyps.get_most_probable(length_norm=True)
matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
if matched and ac_prob >= matched_state.ac_threshold:
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
timestamps=top_hyp.timestamp[-matched_state.level :],
phrase=matched_state.phrase,
)
sorted_ans[i].append(keyword)
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,

View File

@ -479,14 +479,18 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def gigaspeech_subset_small_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech subset-S cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
)
@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)

View File

@ -1,5 +1,70 @@
## Results
### SPGISpeech BPE training results (Zipformer Transducer)
#### 2024-01-05
#### Zipformer encoder + embedding decoder
Transducer: Zipformer encoder + stateless decoder.
The WERs are:
| | dev | val | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.08 | 2.14 | --epoch 30 --avg 10 |
| modified beam search | 2.05 | 2.09 | --epoch 30 --avg 10 --beam-size 4 |
| fast beam search | 2.07 | 2.17 | --epoch 30 --avg 10 --beam 20 --max-contexts 8 --max-states 64 |
**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
transcripts are orthographic or normalized. These WERs correspond to the normalized transcription
scenario.
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python zipformer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--num-workers 2 \
--max-duration 1000
```
The decoding command is:
```
# greedy search
python ./zipformer/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./zipformer/exp \
--max-duration 1000 \
--decoding-method greedy_search
# modified beam search
python ./zipformer/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./zipformer/exp \
--max-duration 1000 \
--decoding-method modified_beam_search
# fast beam search
python ./zipformer/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./zipformer/exp \
--max-duration 1000 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
### SPGISpeech BPE training results (Pruned Transducer)
#### 2022-05-11
@ -43,28 +108,28 @@ The decoding command is:
```
# greedy search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
# modified beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
# fast beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>

View File

@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule:
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=False,
help="When enabled, the last batch will be dropped",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--gap",
type=float,
@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule:
group.add_argument(
"--num-workers",
type=int,
default=8,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule:
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
logging.info("Using DynamicBucketingSampler.")
@ -276,10 +292,12 @@ class SPGISpeechAsrDataModule:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
@ -303,6 +321,7 @@ class SPGISpeechAsrDataModule:
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1,382 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
Note: This is an example for the spgispeech dataset; if you are using a different
dataset, you should change the argument values accordingly.
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./zipformer/exp/epoch-xx.pt`.
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
logging.info("Creating model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
hyps = []
msg = f"Using {params.method}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -602,11 +602,9 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids_with_bpe(texts)
if type(y) == list:
y = k2.RaggedTensor(y).to(device)
else:
y = y.to(device)
y = graph_compiler.texts_to_ids(texts, sep="/")
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,

View File

@ -0,0 +1,142 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
import lhotse
from pathlib import Path
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
fix_manifests,
validate_recordings_and_supervisions,
)
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--kaldi-dir",
type=str,
help="""The directory containing kaldi style manifest, namely wav.scp, text and segments.
""",
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bank bins.
""",
)
parser.add_argument(
"--output-dir",
type=str,
default="data/fbank",
help="""The directory where the lhotse manifests and features to write to.
""",
)
parser.add_argument(
"--dataset",
type=str,
help="""The name of dataset.
""",
)
parser.add_argument(
"--partition",
type=str,
help="""Could be something like train, valid, test and so on.
""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=True,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)
parser.add_argument(
"--num-jobs", type=int, default=50, help="The num of jobs to extract feature."
)
return parser.parse_args()
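# A hypothetical invocation (placeholder paths; replace <this_script> with the
# actual file name under which this script is saved):
#   python ./local/<this_script>.py --kaldi-dir download/kaldi_data/train \
#     --dataset mydata --partition train --num-mel-bins 80 --num-jobs 20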
def prepare_cuts(args):
logging.info(f"Prepare cuts from {args.kaldi_dir}.")
recordings, supervisions, _ = lhotse.load_kaldi_data_dir(args.kaldi_dir, 16000)
recordings, supervisions = fix_manifests(recordings, supervisions)
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
return cuts
def compute_feature(args, cuts):
extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{args.dataset}_cuts_{args.partition}.jsonl.gz"
if (args.output_dir / cuts_filename).is_file():
logging.info(f"{cuts_filename} already exists - skipping.")
return
logging.info(f"Processing {cuts_filename}")
if "train" in args.partition:
if args.perturb_speed:
logging.info(f"Doing speed perturb")
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
cuts = cuts.compute_and_store_features(
extractor=extractor,
storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}",
# when an executor is specified, make more partitions
num_jobs=args.num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cuts.to_file(args.output_dir / cuts_filename)
def main(args):
args.kaldi_dir = Path(args.kaldi_dir)
args.output_dir = Path(args.output_dir)
cuts = prepare_cuts(args)
compute_feature(args, cuts)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
main(args)

View File

@ -0,0 +1,275 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from icefall.utils import text_to_pinyin
def get_parser():
parser = argparse.ArgumentParser(
description="Prepare lang for pinyin",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
parser.add_argument(
"--token-type",
default="full_with_tone",
type=str,
help="""The type of pinyin, should be in:
full_with_tone: zhōng guó
full_no_tone: zhong guo
partial_with_tone: zh ōng g
partial_no_tone: zh ong g uo
""",
)
parser.add_argument(
"--pinyin-errors",
default="split",
type=str,
help="""How to handle characters that has no pinyin,
see `text_to_pinyin` in icefall/utils.py for details
""",
)
return parser
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
args, token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that maps tokens to token ids.
words:
A list of strings representing words.
Returns:
Return a lexicon: a list of (word, tokens) pairs, where tokens is the
pinyin token sequence of the word.
"""
lexicon = []
for word in words:
tokens = text_to_pinyin(
word.strip(), mode=args.token_type, errors=args.pinyin_errors
)
if contain_oov(token_sym_table, tokens):
print(f"Word : {word} contains OOV token, skipping.")
continue
lexicon.append((word, tokens))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(args, words: List[str]) -> Dict[str, int]:
"""Generate tokens from the given word list.
Args:
words:
A list of words from which to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token ids ranged
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
for word in words:
word = word.strip()
tokens_list = text_to_pinyin(
word, mode=args.token_type, errors=args.pinyin_errors
)
for token in tokens_list:
if token not in tokens:
tokens[token] = len(tokens)
return tokens
def main():
parser = get_parser()
args = parser.parse_args()
lang_dir = Path(args.lang_dir)
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(args, words)
lexicon = generate_lexicon(args, token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -364,4 +364,18 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
--vocab-size 5537 \
--master-port 12340
fi
fi
if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
log "Stage 22: Prepare pinyin based lang"
for token in full_with_tone partial_with_tone; do
lang_dir=data/lang_${token}
if [ ! -f $lang_dir/tokens.txt ]; then
cp data/lang_char/words.txt $lang_dir/words.txt
python local/prepare_pinyin.py \
--token-type $token \
--lang-dir $lang_dir
fi
python ./local/compile_lg.py --lang-dir $lang_dir
done
fi

View File

@ -0,0 +1,58 @@
# Results
## zipformer transducer model
This is a tiny general ASR model with around 3.3M parameters; see PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details.
The modeling units are partial pinyin (i.e., initials and finals) with tone, as illustrated by the short sketch below.
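The following minimal example (assuming `icefall` is installed so that `text_to_pinyin` from `icefall/utils.py` is importable) shows how a Chinese phrase is converted into such units; the printed tokens are illustrative.

```
from icefall.utils import text_to_pinyin

# "partial_with_tone" splits every syllable into an initial and a tonal final.
tokens = text_to_pinyin("打开卧室灯", mode="partial_with_tone", errors="split")
print(tokens)  # roughly: ['d', 'ǎ', 'k', 'āi', 'w', 'ò', 'sh', 'ì', 'd', 'ēng']
```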
The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is the TEST_NET subset of WenetSpeech (about 23 hours of audio).
The whole pipeline (training, decoding and fine-tuning commands) is in `run.sh`.
The models have been uploaded to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-wenetspeech-20240219.tar.gz).
Here are the results on a small test set of 20 commands. We list the results for every command; for each metric there are two columns,
one for the original model trained on the WenetSpeech L subset and the other for the model fine-tuned on an in-house commands dataset (about 90 hours of audio).
> You can see that the performance of the original model is very poor. We think the reason is that the test commands are all collected from real product scenarios, which are very different from the scenarios in which the WenetSpeech data was collected. After fine-tuning, the performance improves a lot.
Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (times / hour, 23 hours) | False alarm (times / hour, 23 hours)
-- | -- | -- | -- | -- | -- | -- | -- | --
  | original | finetune | original | finetune | original | finetune | original | finetune
All | 426 / 985 | 40/985 | 56.8% | 95.9% | 7 | 1 | 0.3 | 0.04
下一个 | 5/50 | 0/50 | 90% | 100% | 3 | 0 | 0.13 | 0
开灯 | 19/49 | 2/49 | 61.2% | 95.9% | 0 | 0 | 0 | 0
第一个 | 11/50 | 3/50 | 78% | 94% | 3 | 0 | 0.13 | 0
声音调到最大 | 39/50 | 7/50 | 22% | 86% | 0 | 0 | 0 | 0
暂停音乐 | 36/49 | 1/49 | 26.5% | 98% | 0 | 0 | 0 | 0
暂停播放 | 33/49 | 2/49 | 32.7% | 95.9% | 0 | 0 | 0 | 0
打开卧室灯 | 33/49 | 1/49 | 32.7% | 98% | 0 | 0 | 0 | 0
关闭所有灯 | 27/50 | 0/50 | 46% | 100% | 0 | 0 | 0 | 0
关灯 | 25/48 | 2/48 | 47.9% | 95.8% | 1 | 1 | 0.04 | 0.04
关闭导航 | 25/48 | 1/48 | 47.9% | 97.9% | 0 | 0 | 0 | 0
打开蓝牙 | 24/47 | 0/47 | 48.9% | 100% | 0 | 0 | 0 | 0
下一首歌 | 21/50 | 1/50 | 58% | 98% | 0 | 0 | 0 | 0
换一首歌 | 19/50 | 5/50 | 62% | 90% | 0 | 0 | 0 | 0
继续播放 | 19/50 | 2/50 | 62% | 96% | 0 | 0 | 0 | 0
打开闹钟 | 18/49 | 2/49 | 63.3% | 95.9% | 0 | 0 | 0 | 0
打开音乐 | 17/49 | 0/49 | 65.3% | 100% | 0 | 0 | 0 | 0
打开导航 | 17/48 | 0/49 | 64.6% | 100% | 0 | 0 | 0 | 0
打开电视 | 15/50 | 0/49 | 70% | 100% | 0 | 0 | 0 | 0
大点声 | 12/50 | 5/50 | 76% | 90% | 0 | 0 | 0 | 0
小点声 | 11/50 | 6/50 | 78% | 88% | 0 | 0 | 0 | 0
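To make the relation between the columns explicit (this is our reading of the table, not an official definition), recall is computed as `1 - FN / #positive utterances` and the false-alarm rate as `FP / 23 hours` of negative audio. For example, for the `All` row of the original model:

```python
# Reproduce the "All" row (original model) from its raw counts, assuming
# recall = 1 - FN / positives and false-alarm rate = FP / negative hours.
fn, positives = 426, 985
fp, negative_hours = 7, 23

print(f"recall = {1 - fn / positives:.1%}")                  # 56.8%
print(f"false alarms per hour = {fp / negative_hours:.2f}")  # 0.30
```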
This is the result on the large test set. It has more than 100 commands, too many to list individually, so only an overall result is given here. We also list the results of two wake-up words: 小云小云 (test set only) and 你好问问 (both training and test sets). For 你好问问, we have two fine-tuned models: one fine-tuned on 你好问问 plus our in-house commands data, the other fine-tuned only on 你好问问. Both models perform much better than the original model; the one fine-tuned only on 你好问问 behaves slightly better than the other.
> The 小云小云 test set and the 你好问问 training, dev and test sets are available at https://github.com/pkufool/open-commands
Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (times / hour, 23 hours) | False alarm (times / hour, 23 hours)
-- | -- | -- | -- | -- | -- | -- | -- | --
  | original | finetune | original | finetune | original | finetune | original | finetune
large | 2429/4505 | 477 / 4505 | 46.1% | 89.4% | 50 | 41 | 2.17 | 1.78
小云小云 (clean) | 30/100 | 40/100 | 70% | 60% | 0 | 0 | 0 | 0
小云小云 (noisy) | 118/350 | 154/350 | 66.3% | 56% | 0 | 0 | 0 | 0
你好问问 (finetuned with all keywords data) | 2236/10641 | 678/10641 | 79% | 93.6% | 0 | 0 | 0 | 0
你好问问 (finetuned with only 你好问问) | 2236/10641 | 249/10641 | 79% | 97.7% | 0 | 0 | 0 | 0

egs/wenetspeech/KWS/prepare.sh Executable file
View File

@ -0,0 +1,90 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare wewetspeech dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.wewetspeech.done ]; then
pushd ../ASR
./prepare.sh --stage 0 --stop-stage 17
./prepare.sh --stage 22 --stop-stage 22
popd
pushd data/fbank
ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) .
ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) .
ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) .
ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/L_split_1000) .
ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/M_split_1000) .
ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/S_split_1000) .
ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
ln -svf $(realpath ../ASR/data/lang_partial_tone) .
popd
touch data/fbank/.wenetspeech.done
else
log "WenetSpeech dataset already exists, skipping."
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.cn_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
pushd open-commands
./script/prepare.sh --stage 1 --stop-stage 1
./script/prepare.sh --stage 3 --stop-stage 5
popd
popd
pushd data/fbank
ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) .
ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) .
ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) .
ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) .
popd
touch data/fbank/.cn_speech_commands.done
else
log "CN speech commands dataset already exists, skipping."
fi
fi

egs/wenetspeech/KWS/run.sh Executable file
View File

@ -0,0 +1,201 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=../../../:$PYTHONPATH
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
if [ ! -e data/fbank/.wenetspeech.done ]; then
log "You need to run the prepare.sh first."
exit -1
fi
python ./zipformer/train.py \
--world-size 4 \
--exp-dir zipformer/exp \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--num-epochs 18 \
--lr-epochs 1.5 \
--use-fp16 1 \
--start-epoch 1 \
--training-subset L \
--pinyin-type partial_with_tone \
--causal 1 \
--lang-dir data/lang_partial_tone \
--max-duration 1000
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Decode the model."
for t in small large; do
python ./zipformer/decode.py \
--epoch 18 \
--avg 2 \
--exp-dir ./zipformer/exp \
--tokens ./data/lang_partial_tone/tokens.txt \
--pinyin-type partial_with_tone \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.5 \
--keywords-threshold 0.1 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Export the model."
python ./zipformer/export.py \
--epoch 18 \
--avg 2 \
--exp-dir ./zipformer/exp \
--tokens data/lang_partial_tone/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp \
--tokens data/lang_partial_tone/tokens.txt \
--epoch 18 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
base_lr=0.0005
lr_epochs=100
lr_batches=100000
# We recommend starting from an averaged model
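# e.g. zipformer/exp/pretrained.pt exported by ./zipformer/export.py in stage 2 above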
finetune_ckpt=zipformer/exp/pretrained.pt
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--exp-dir zipformer/exp_finetune \
--lang-dir ./data/lang_partial_tone \
--pinyin-type partial_with_tone \
--use-fp16 1 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1 \
--base-lr $base_lr \
--lr-epochs $lr_epochs \
--lr-batches $lr_batches \
--finetune-ckpt $finetune_ckpt \
--max-duration 1500
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 1: Decode the finetuned model."
for t in small, large; do
python ./zipformer/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--tokens ./data/lang_partial_tone/tokens.txt \
--pinyin-type partial_with_tone \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 0.000001 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 2: Export the finetuned model."
python ./zipformer/export.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--tokens data/lang_partial_tone/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp_finetune \
--tokens data/lang_partial_tone/tokens.txt \
--epoch 10 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi

egs/wenetspeech/KWS/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1,459 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
load_manifest,
load_manifest_lazy,
set_caching_enabled,
)
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class WenetSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--training-subset",
type=str,
default="L",
help="The training subset for using",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=300000,
drop_last=True,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_dl.sampler.load_state_dict(sampler_state_dict)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
batch_size=None,
sampler=valid_sampler,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
@lru_cache()
def test_net_cuts(self) -> CutSet:
logging.info("About to get TEST_NET cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
@lru_cache()
def test_meeting_cuts(self) -> CutSet:
logging.info("About to get TEST_MEETING cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
@lru_cache()
def cn_speech_commands_small_cuts(self) -> CutSet:
logging.info("About to get cn speech commands small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cn_speech_commands_cuts_small.jsonl.gz"
)
@lru_cache()
def cn_speech_commands_large_cuts(self) -> CutSet:
logging.info("About to get cn speech commands large cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cn_speech_commands_cuts_large.jsonl.gz"
)
@lru_cache()
def nihaowenwen_dev_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_dev.jsonl.gz"
)
@lru_cache()
def nihaowenwen_test_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_test.jsonl.gz"
)
@lru_cache()
def nihaowenwen_train_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_train.jsonl.gz"
)
@lru_cache()
def xiaoyun_clean_cuts(self) -> CutSet:
logging.info("About to get xiaoyun clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "xiaoyun_cuts_clean.jsonl.gz"
)
@lru_cache()
def xiaoyun_noisy_cuts(self) -> CutSet:
logging.info("About to get xiaoyun noisy cuts")
return load_manifest_lazy(
self.args.manifest_dir / "xiaoyun_cuts_noisy.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,767 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) fast beam search (LG)
./zipformer/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method:
key += f"_beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}_" + key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest_oracle",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
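# Estimate the number of output frames after the encoder's subsampling;
# cuts that would yield no output frames are excluded from decoding.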
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
dev_cuts = wenetspeech.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,737 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import (
keywords_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
num_tokens,
setup_logger,
store_transcripts,
str2bool,
text_to_pinyin,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@dataclass
class KwMetric:
TP: int = 0 # True positive
FN: int = 0 # False negative
FP: int = 0 # False positive
TN: int = 0 # True negative
FN_list: List[str] = field(default_factory=list)
FP_list: List[str] = field(default_factory=list)
TP_list: List[str] = field(default_factory=list)
def __str__(self) -> str:
return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/lang_partial_tone/tokens.txt",
help="The path to the token.txt",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--pinyin-type",
type=str,
help="The type of pinyin used as the modeling units.",
)
parser.add_argument(
"--keywords-file",
type=str,
help="File contains keywords.",
)
parser.add_argument(
"--test-set",
type=str,
default="small",
help="small or large",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
help="""
The default boosting score (token level) for keywords. It will boost the
paths that match keywords so that they survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=1,
help="The number of tailing blanks should have after hitting one keyword.",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
keywords_graph: ContextGraph,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
keywords_graph:
The context graph built from the given keywords. Paths matching a
keyword are boosted during beam search.
Returns:
Return the detected keywords. See above description for the format of
the returned value.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
ans_dict = keywords_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
keywords_graph=keywords_graph,
beam=params.beam_size,
num_tailing_blanks=params.num_tailing_blanks,
)
hyps = []
for ans in ans_dict:
hyp = []
for hit in ans:
hyp.append(
(
hit.phrase,
(hit.timestamps[0], hit.timestamps[-1]),
)
)
hyps.append(hyp)
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
keywords_graph:
The context graph built from the given keywords.
keywords:
The set of keyword phrases; used when computing the metrics.
test_only_keywords:
If True, every utterance in the test set is exactly one keyword phrase,
so a hit counts as correct only when it matches the whole transcript.
Returns:
Return a tuple (results, metric). `results` is a list of
(cut_id, ref_words, hyp_words) tuples and `metric` maps each keyword
(and "all") to its KwMetric counts.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 20
results = []
metric = {"all": KwMetric()}
for k in keywords:
metric[k] = KwMetric()
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
keywords_graph=keywords_graph,
batch=batch,
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = list(ref_text)
hyp_words = [x[0] for x in hyp_words]
this_batch.append((cut_id, ref_words, list("".join(hyp_words))))
hyp_set = set(hyp_words)
if len(hyp_words) > 1:
logging.warning(
f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
f"please check the transcript to see if it really has more "
f"than one keywords, if so consider splitting this audio and"
f"keep only one keyword for each audio."
)
hyp_str = " | ".join(
hyp_words
) # The triggered keywords for this utterance.
TP = False
FP = False
for x in hyp_set:
assert x in keywords, x # can only trigger keywords
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
TP = True
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if (test_only_keywords and x != ref_text) or (
not test_only_keywords and x not in ref_text
):
FP = True
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x})")
if TP:
metric["all"].TP += 1
if FP:
metric["all"].FP += 1
TN = True  # If all keywords are true negatives, then the summary is a true negative.
FN = False
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric[x].TN += 1
continue
TN = False
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
fn = True
for y in hyp_set:
if (test_only_keywords and y == ref_text) or (
not test_only_keywords and y in ref_text
):
fn = False
break
if fn:
FN = True
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
if TN:
metric["all"].TN += 1
if FN:
metric["all"].FN += 1
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, metric
def save_results(
params: AttributeDict,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
metric: KwMetric,
):
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (
0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
)
recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
s += f"\tAccuracy: {acc:.3f}\n"
s += f"\tPrecision: {precision:.3f}\n"
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
of.write(f"\n\n{params.keywords_config}")
logging.info("Wrote metric stats to {}".format(metric_filename))
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "kws"
params.suffix = params.test_set
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"-score-{params.keywords_score}"
params.suffix += f"-threshold-{params.keywords_threshold}"
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
phrases = []
token_ids = []
keywords_scores = []
keywords_thresholds = []
keywords_config = []
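# A sketch of the keywords file format, inferred from the parsing loop below
# (illustrative, not an official spec): one keyword per line, optionally
# followed by a boosting score prefixed with ":" and an acoustic threshold
# prefixed with "#", e.g.
#   你好问问 :2.0 #0.35
#   小云小云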
with open(params.keywords_file, "r") as f:
for line in f.readlines():
keywords_config.append(line)
score = 0
threshold = 0
keyword = []
words = line.strip().upper().split()
for word in words:
word = word.strip()
if word[0] == ":":
score = float(word[1:])
continue
if word[0] == "#":
threshold = float(word[1:])
continue
keyword.append(word)
keyword = "".join(keyword)
tmp_ids = []
kws_py = text_to_pinyin(keyword, mode=params.pinyin_type)
for k in kws_py:
if k in token_table:
tmp_ids.append(token_table[k])
else:
logging.warning(f"Containing OOV tokens, skipping line : {line}")
tmp_ids = []
break
if tmp_ids:
logging.info(f"Adding keyword : {keyword}")
phrases.append(keyword)
token_ids.append(tmp_ids)
keywords_scores.append(score)
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
ac_thresholds=keywords_thresholds,
)
keywords = set(phrases)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
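# Sanity check of the subsampling formula above with assumed values:
# num_frames = 9 -> T = ((9 - 7) // 2 + 1) // 2 = 1 (kept)
# num_frames = 8 -> T = ((8 - 7) // 2 + 1) // 2 = 0 (excluded)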
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
cn_commands_small_cuts = wenetspeech.cn_speech_commands_small_cuts()
cn_commands_small_cuts = cn_commands_small_cuts.filter(remove_short_utt)
cn_commands_small_dl = wenetspeech.test_dataloaders(cn_commands_small_cuts)
cn_commands_large_cuts = wenetspeech.cn_speech_commands_large_cuts()
cn_commands_large_cuts = cn_commands_large_cuts.filter(remove_short_utt)
cn_commands_large_dl = wenetspeech.test_dataloaders(cn_commands_large_cuts)
nihaowenwen_test_cuts = wenetspeech.nihaowenwen_test_cuts()
nihaowenwen_test_cuts = nihaowenwen_test_cuts.filter(remove_short_utt)
nihaowenwen_test_dl = wenetspeech.test_dataloaders(nihaowenwen_test_cuts)
xiaoyun_clean_cuts = wenetspeech.xiaoyun_clean_cuts()
xiaoyun_clean_cuts = xiaoyun_clean_cuts.filter(remove_short_utt)
xiaoyun_clean_dl = wenetspeech.test_dataloaders(xiaoyun_clean_cuts)
xiaoyun_noisy_cuts = wenetspeech.xiaoyun_noisy_cuts()
xiaoyun_noisy_cuts = xiaoyun_noisy_cuts.filter(remove_short_utt)
xiaoyun_noisy_dl = wenetspeech.test_dataloaders(xiaoyun_noisy_cuts)
test_sets = []
test_dls = []
if params.test_set == "large":
test_sets += ["cn_commands_large", "test_net"]
test_dls += [cn_commands_large_dl, test_net_dl]
else:
assert params.test_set == "small", params.test_set
test_sets += [
"cn_commands_small",
"nihaowenwen",
"xiaoyun_clean",
"xiaoyun_noisy",
"test_net",
]
test_dls += [
cn_commands_small_dl,
nihaowenwen_test_dl,
xiaoyun_clean_dl,
xiaoyun_noisy_dl,
test_net_dl,
]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
dl=test_dl,
params=params,
model=model,
keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="test_net" not in test_set,
)
save_results(
params=params,
test_set_name=test_set,
results=results,
metric=metric,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1,814 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model finetuning:
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For non-streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-mux 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For streaming model finetuning:
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
# For streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-mux 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
"""
import argparse
import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
text_to_pinyin,
)
from train import (
add_model_arguments,
add_training_arguments,
compute_validation_loss,
display_and_save_batch,
get_adjusted_batch_count,
get_model,
get_params,
load_checkpoint_if_available,
save_checkpoint,
scan_pessimistic_batches_for_oom,
set_batch_count,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to mux the new data with the original training data when
fine-tuning. If true, the new data is mixed with the original data
(see the weights passed to CutSet.mux in run()).
""",
)
parser.add_argument(
"--init-modules",
type=str,
default=None,
help="""
Modules to be initialized. It matches all parameters starting with
a specific key. The keys are given as a comma-separated list. If None,
all modules will be initialised. For example, if you only want to
initialise all parameters starting with "encoder", use "encoder";
if you want to initialise parameters starting with encoder or joiner,
use "encoder,joiner".
""",
)
parser.add_argument(
"--finetune-ckpt",
type=str,
default=None,
help="Fine-tuning from which checkpoint (a path to a .pt file)",
)
parser.add_argument(
"--continue-finetune",
type=str2bool,
default=False,
help="Continue finetuning or finetune from pre-trained model",
)
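# A hedged example (paths and values are placeholders) of combining the
# fine-tuning arguments defined above with the common training arguments
# added by add_training_arguments():
#
#   ./zipformer/finetune.py \
#     --world-size 4 \
#     --num-epochs 10 \
#     --start-epoch 1 \
#     --use-fp16 1 \
#     --finetune-ckpt path/to/pretrained.pt \
#     --init-modules "encoder" \
#     --exp-dir zipformer/exp_finetune \
#     --max-duration 1000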
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_partial_tone",
help="Path to the pinyin lang directory",
)
parser.add_argument(
"--pinyin-type",
type=str,
default="partial_with_tone",
help="""
The style of the output pinyin, should be:
full_with_tone : zhōng guó
full_no_tone : zhong guo
partial_with_tone : zh ōng g uó
partial_no_tone : zh ong g uo
""",
)
parser.add_argument(
"--pinyin-errors",
default="split",
type=str,
help="""How to handle characters that has no pinyin,
see `text_to_pinyin` in icefall/utils.py for details
""",
)
add_training_arguments(parser)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
def load_model_params(
ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
# if module list is empty, load the whole model from ckpt
if not init_modules:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
else:
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
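# A minimal usage sketch of load_model_params (the checkpoint path is a
# placeholder): initialise only the encoder from a pre-trained model.
#   model = get_model(params)
#   load_model_params(
#       ckpt="path/to/pretrained.pt",
#       model=model,
#       init_modules=["encoder"],
#   )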
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Zipformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids(texts, sep="/")
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = 0.0
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss_scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
else:
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
if c.duration > 15:
return False
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
wenetspeech = WenetSpeechAsrDataModule(args)
if params.use_mux:
train_cuts = CutSet.mux(
wenetspeech.train_cuts(),
wenetspeech.nihaowenwen_train_cuts(),
weights=[0.9, 0.1],
)
else:
train_cuts = wenetspeech.nihaowenwen_train_cuts()
def encode_text(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = "/".join(
text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
)
c.supervisions[0].text = text
return c
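# Illustrative (assumed) effect of encode_text with
# pinyin_type="partial_with_tone": "你好" -> "n/ǐ/h/ǎo".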
train_cuts = train_cuts.filter(remove_short_utt)
train_cuts = train_cuts.map(encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
train_dl = wenetspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_cuts = valid_cuts.map(encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -54,7 +54,7 @@ class CharCtcTrainingGraphCompiler(object):
self.sos_id = self.token_table[sos_token]
self.eos_id = self.token_table[eos_token]
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
def texts_to_ids(self, texts: List[str], sep: str = "") -> List[List[int]]:
"""Convert a list of texts to a list-of-list of token IDs.
Args:
@ -63,36 +63,21 @@ class CharCtcTrainingGraphCompiler(object):
An example containing two strings is given below:
['你好中国', '北京欢迎您']
sep:
The separator between the items in one sequence: no separator for plain
Chinese (one character per token), "/" for Chinese characters plus BPE
tokens and for pinyin tokens.
Returns:
Return a list-of-list of token IDs.
"""
assert sep in ("", "/"), sep
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in texts:
text = re.sub(whitespace, "", text)
sub_ids = [
self.token_table[txt] if txt in self.token_table else self.oov_id
for txt in text
]
ids.append(sub_ids)
return ids
def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts (which include chars and bpes)
to a list-of-list of token IDs.
Args:
texts:
It is a list of strings.
An example containing two strings is given below:
[['', '', '▁C', 'hina'], ['','', '', 'welcome', '']
Returns:
Return a list-of-list of token IDs.
"""
ids: List[List[int]] = []
for text in texts:
text = text.split("/")
if sep == "":
text = re.sub(whitespace, "", text)
else:
text = text.split(sep)
sub_ids = [
self.token_table[txt] if txt in self.token_table else self.oov_id
for txt in text

View File

@ -17,7 +17,7 @@
import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
class ContextState:
@ -31,6 +31,9 @@ class ContextState:
node_score: float,
output_score: float,
is_end: bool,
level: int,
phrase: str = "",
ac_threshold: float = 1.0,
):
"""Create a ContextState.
@ -51,6 +54,15 @@ class ContextState:
the output node for current node.
is_end:
True if current token is the end of a context.
level:
The distance from current node to root.
phrase:
The context phrase of current state, the value is valid only when
current state is end state (is_end == True).
ac_threshold:
The acoustic threshold (probability) of current context phrase, the
value is valid only when current state is end state (is_end == True).
Note: ac_threshold is only used in keyword spotting.
"""
self.id = id
self.token = token
@ -58,7 +70,10 @@ class ContextState:
self.node_score = node_score
self.output_score = output_score
self.is_end = is_end
self.level = level
self.next = {}
self.phrase = phrase
self.ac_threshold = ac_threshold
self.fail = None
self.output = None
@ -75,7 +90,7 @@ class ContextGraph:
beam search.
"""
def __init__(self, context_score: float):
def __init__(self, context_score: float, ac_threshold: float = 1.0):
"""Initialize a ContextGraph with the given ``context_score``.
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
@ -87,8 +102,12 @@ class ContextGraph:
Note: This is just the default score for each token, the users can manually
specify the context_score for each word/phrase (i.e. different phrase might
have different token score).
ac_threshold:
The acoustic threshold (probability) to trigger the word/phrase; this argument
is used only when applying the graph to a keyword spotting system.
"""
self.context_score = context_score
self.ac_threshold = ac_threshold
self.num_nodes = 0
self.root = ContextState(
id=self.num_nodes,
@ -97,6 +116,7 @@ class ContextGraph:
node_score=0,
output_score=0,
is_end=False,
level=0,
)
self.root.fail = self.root
@ -136,7 +156,13 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score
queue.append(node)
def build(self, token_ids: List[Tuple[List[int], float]]):
def build(
self,
token_ids: List[List[int]],
phrases: Optional[List[str]] = None,
scores: Optional[List[float]] = None,
ac_thresholds: Optional[List[float]] = None,
):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
@ -145,52 +171,80 @@ class ContextGraph:
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of tuple of
token list and its customized score, the token list contains the token ids
The given token lists to build the ContextGraph, it is a list of
token list, the token list contains the token ids
for a word/phrase. The token id could be an id of a char
(modeling with single Chinese char) or an id of a BPE
(modeling with BPEs). The score is the total score for current token list,
(modeling with BPEs).
phrases:
The given phrases, they are the original text of the token_ids, the
length of `phrases` MUST be equal to the length of `token_ids`.
scores:
The customized boosting score (token level) for each word/phrase,
0 means using the default value (i.e. self.context_score).
It is a list of floats, and the length of `scores` MUST be equal to
the length of `token_ids`.
ac_thresholds:
The customized trigger acoustic threshold (probability) for each phrase;
0 means using the default value (i.e. self.ac_threshold). It is
used only when this graph is applied to a keyword spotting system.
The length of `ac_thresholds` MUST be equal to the length of `token_ids`.
Note: The phrases would have shared states, the score of the shared states is
the maximum value among all the tokens sharing this state.
the MAXIMUM value among all the tokens sharing this state.
"""
for (tokens, score) in token_ids:
num_phrases = len(token_ids)
if phrases is not None:
assert len(phrases) == num_phrases, (len(phrases), num_phrases)
if scores is not None:
assert len(scores) == num_phrases, (len(scores), num_phrases)
if ac_thresholds is not None:
assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases)
for index, tokens in enumerate(token_ids):
phrase = "" if phrases is None else phrases[index]
score = 0.0 if scores is None else scores[index]
ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index]
node = self.root
# If a customized score is given, use it; otherwise use the
# default score.
context_score = (
self.context_score if score == 0.0 else round(score / len(tokens), 2)
)
context_score = self.context_score if score == 0.0 else score
threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
for i, token in enumerate(tokens):
node_next = {}
if token not in node.next:
self.num_nodes += 1
node_id = self.num_nodes
token_score = context_score
is_end = i == len(tokens) - 1
node_score = node.node_score + context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=context_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
level=i + 1,
phrase=phrase if is_end else "",
ac_threshold=threshold if is_end else 0.0,
)
else:
# node exists, get the score of shared state.
token_score = max(context_score, node.next[token].token_score)
node_id = node.next[token].id
node_next = node.next[token].next
node.next[token].token_score = token_score
node_score = node.node_score + token_score
node.next[token].node_score = node_score
is_end = i == len(tokens) - 1 or node.next[token].is_end
node_score = node.node_score + token_score
node.next[token] = ContextState(
id=node_id,
token=token,
token_score=token_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node.next[token].next = node_next
node.next[token].output_score = node_score if is_end else 0
node.next[token].is_end = is_end
if i == len(tokens) - 1:
node.next[token].phrase = phrase
node.next[token].ac_threshold = threshold
node = node.next[token]
self._fill_fail_output()
def forward_one_step(
self, state: ContextState, token: int
) -> Tuple[float, ContextState]:
self, state: ContextState, token: int, strict_mode: bool = True
) -> Tuple[float, ContextState, ContextState]:
"""Search the graph with given state and token.
Args:
@ -198,9 +252,27 @@ class ContextGraph:
The given token containing trie node to start.
token:
The given token.
strict_mode:
If the `strict_mode` is True, it can match multiple phrases simultaneously,
and will continue to match longer phrase after matching a shorter one.
If the `strict_mode` is False, it can only match one phrase at a time:
when it matches a phrase, the state falls back to the root state
(i.e. forgetting all the history and starting a new match). If
the matched state has multiple outputs (node.output is not None), the
longest phrase will be returned.
For example, if the phrases are `he`, `she` and `shell`, and the query is
`like shell`, when `strict_mode` is True, the query will match `he` and
`she` at token `e` and `shell` at token `l`, while when `strict_mode`
is False, the query can only match `she` (`she` is longer than `he`, so
`she`, not `he`) at token `e`.
Caution: When applying this graph to a keyword spotting system,
`strict_mode` MUST be True.
Returns:
Return a tuple of score and next state.
Return a tuple of the boosting score for the current step, the next state
and the matched state (if any). Note: only the matched state with the
longest phrase of the current state is returned, even if multiple phrases
match. If no phrase matches, the matched state is None.
"""
node = None
score = 0
@ -224,7 +296,49 @@ class ContextGraph:
# The score of the fail path
score = node.node_score - state.node_score
assert node is not None
return (score + node.output_score, node)
# The matched node of the current step; only the node with the longest
# phrase is kept if multiple phrases match at this step.
# None if no phrase matches.
matched_node = (
node if node.is_end else (node.output if node.output is not None else None)
)
if not strict_mode and node.output_score != 0:
# output_score != 0 means at least one phrase matched
assert matched_node is not None
output_score = (
node.node_score
if node.is_end
else (
node.node_score if node.output is None else node.output.node_score
)
)
return (score + output_score - node.node_score, self.root, matched_node)
assert (node.output_score != 0 and matched_node is not None) or (
node.output_score == 0 and matched_node is None
), (
node.output_score,
matched_node,
)
return (score + node.output_score, node, matched_node)
def is_matched(self, state: ContextState) -> Tuple[bool, ContextState]:
"""Whether current state matches any phrase (i.e. current state is the
end state or the output of current state is not None.
Args:
state:
The given state(trie node).
Returns:
Return a tuple of status and matched state.
"""
if state.is_end:
return True, state
else:
if state.output is not None:
return True, state.output
return False, None
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
@ -366,7 +480,7 @@ class ContextGraph:
return dot
def _test(queries, score):
def _test(queries, score, strict_mode):
contexts_str = [
"S",
"HE",
@ -381,11 +495,15 @@ def _test(queries, score):
# test default score (1)
contexts = []
scores = []
phrases = []
for s in contexts_str:
contexts.append(([ord(x) for x in s], score))
contexts.append([ord(x) for x in s])
scores.append(round(score / len(s), 2))
phrases.append(s)
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
context_graph.build(token_ids=contexts, scores=scores, phrases=phrases)
symbol_table = {}
for contexts in contexts_str:
@ -402,7 +520,9 @@ def _test(queries, score):
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
score, state, phrase = context_graph.forward_one_step(
state, ord(q), strict_mode
)
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
@ -427,9 +547,22 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
_test(queries, 0)
_test(queries, 0, True)
# test custom score (5)
queries = {
"HEHERSHE": 7, # "HE", "HE", "S", "HE"
"HERSHE": 5, # "HE", "S", "HE"
"HISHE": 5, # "HIS", "HE"
"SHED": 3, # "S", "HE"
"SHELF": 3, # "S", "HE"
"HELL": 2, # "HE"
"HELLO": 2, # "HE"
"DHRHISQ": 3, # "HIS"
"THEN": 2, # "HE"
}
_test(queries, 0, False)
# test custom score
# S : 5
# HE : 5 (2.5 + 2.5)
# SHE : 8.34 (5 + 1.67 + 1.67)
@ -450,4 +583,17 @@ if __name__ == "__main__":
"THEN": 5, # "HE"
}
_test(queries, 5)
_test(queries, 5, True)
queries = {
"HEHERSHE": 20, # "HE", "HE", "S", "HE"
"HERSHE": 15, # "HE", "S", "HE"
"HISHE": 10.84, # "HIS", "HE"
"SHED": 10, # "S", "HE"
"SHELF": 10, # "S", "HE"
"HELL": 5, # "HE"
"HELLO": 5, # "HE"
"DHRHISQ": 5.84, # "HIS"
"THEN": 5, # "HE"
}
_test(queries, 5, False)

View File

@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
"""
if lm_type == "rnn":
model = RnnLmModel(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer":
model = TransformerLM(
vocab_size=params.vocab_size,
vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,

View File

@ -28,6 +28,8 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from pypinyin import pinyin, lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
@ -327,6 +329,19 @@ def encode_supervisions_otc(
return supervision_segments, res, sorted_ids, sorted_verbatim_texts
@dataclass
class KeywordResult:
# timestamps[k] contains the frame number on which tokens[k]
# is decoded
timestamps: List[int]
# hyps is the keyword, i.e., word IDs or token IDs
hyps: List[int]
# The triggered phrase
phrase: str
@dataclass
class DecodingResults:
# timestamps[i][k] contains the frame number on which tokens[i][k]
@ -1583,6 +1598,87 @@ def load_averaged_model(
return model
def text_to_pinyin(
txt: str, mode: str = "full_with_tone", errors: str = "default"
) -> List[str]:
"""
Convert a Chinese text (might contain some latin characters) to pinyin sequence.
Args:
txt:
The input Chinese text.
mode:
The style of the output pinyin, should be:
full_with_tone : zhōng guó
full_no_tone : zhong guo
partial_with_tone : zh ōng g uó
partial_no_tone : zh ong g uo
errors:
How to handle the characters (latin) that have no pinyin.
default : output the same as input.
split : split into single characters (i.e. alphabets)
Return:
Return a list of str.
Examples:
txt: 想吃KFC
output: ['xiǎng', 'chī', 'KFC'] # mode=full_with_tone; errors=default
output: ['xiǎng', 'chī', 'K', 'F', 'C'] # mode=full_with_tone; errors=split
output: ['xiang', 'chi', 'KFC'] # mode=full_no_tone; errors=default
output: ['xiang', 'chi', 'K', 'F', 'C'] # mode=full_no_tone; errors=split
output: ['x', 'iǎng', 'ch', 'ī', 'KFC'] # mode=partial_with_tone; errors=default
output: ['x', 'iang', 'ch', 'i', 'KFC'] # mode=partial_no_tone; errors=default
"""
assert mode in (
"full_with_tone",
"full_no_tone",
"partial_no_tone",
"partial_with_tone",
), mode
assert errors in ("default", "split"), errors
txt = txt.strip()
res = []
if "full" in mode:
if errors == "default":
py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt)
else:
py = (
pinyin(txt, errors=lambda x: list(x))
if mode == "full_with_tone"
else lazy_pinyin(txt, errors=lambda x: list(x))
)
res = [x[0] for x in py] if mode == "full_with_tone" else py
else:
if errors == "default":
py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt)
else:
py = (
pinyin(txt, errors=lambda x: list(x))
if mode == "partial_with_tone"
else lazy_pinyin(txt, errors=lambda x: list(x))
)
py = [x[0] for x in py] if mode == "partial_with_tone" else py
for x in py:
initial = to_initials(x, strict=False)
final = (
to_finals(x, strict=False)
if mode == "partial_no_tone"
else to_finals_tone(x, strict=False)
)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
return res
def tokenize_by_bpe_model(
sp: spm.SentencePieceProcessor,
txt: str,

View File

@ -18,6 +18,7 @@ git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11
kaldialign==0.7.1
num2words
pypinyin==0.50.0
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3

View File

@ -4,9 +4,10 @@ kaldialign
num2words
kaldi-decoder
sentencepiece>=0.1.96
pypinyin==0.50.0
tensorboard
typeguard
dill
black==22.3.0
onnx==1.15.0
onnxruntime==1.16.3
onnxruntime==1.16.3