add the pruned_transducer_stateless7_streaming recipe for commonvoice (#1018)

* add the pruned_transducer_stateless7_streaming recipe for commonvoice

* fix the symlinks

* Update RESULTS.md
lishaojie 2023-11-09 22:07:28 +08:00 committed by GitHub
parent 231bbcd2b6
commit 1b2e99d374
42 changed files with 6260 additions and 840 deletions

View File

@ -57,3 +57,28 @@ Pretrained model is available at
The tensorboard log for training is available at
<https://tensorboard.dev/experiment/j4pJQty6RMOkMJtRySREKw/>
### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
See #1018 for more details.
Number of model parameters: 70369391, i.e., 70.37 M
The best WERs for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) are:

| decoding method      | Test WER (%) |
|----------------------|--------------|
| greedy search        | 9.95         |
| modified beam search | 9.57         |
| fast beam search     | 9.67         |
Note: the best results were obtained by pretraining on full LibriSpeech plus GigaSpeech and then fine-tuning on full CommonVoice.

Detailed experimental results and the pretrained model are available at
<https://huggingface.co/shaojieli/icefall-asr-commonvoice-fr-pruned-transducer-stateless7-streaming-2023-04-02>
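For convenience, the published checkpoint can also be fetched programmatically; a minimal sketch using `huggingface_hub` (the destination directory is chosen by the library):

```python
# Sketch: download the pretrained model from the Hugging Face Hub.
# Assumes `pip install huggingface_hub`; repo_id matches the link above.
from huggingface_hub import snapshot_download

model_dir = snapshot_download(
    repo_id="shaojieli/icefall-asr-commonvoice-fr-pruned-transducer-stateless7-streaming-2023-04-02"
)
print(f"Model files downloaded to {model_dir}")
```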

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -56,8 +56,8 @@ def get_args():
def compute_fbank_commonvoice_dev_test(language: str):
src_dir = Path(f"data/{language}/manifests")
output_dir = Path(f"data/{language}/fbank")
-num_workers = 42
-batch_duration = 600
+num_workers = 16
+batch_duration = 200
subsets = ("dev", "test")

View File

@ -43,9 +43,13 @@ def get_args():
return parser.parse_args()
-def normalize_text(utt: str) -> str:
+def normalize_text(utt: str, language: str) -> str:
     utt = re.sub(r"[{0}]+".format("-"), " ", utt)
-    return re.sub(r"[^a-zA-Z\s']", "", utt).upper()
+    utt = re.sub("’", "'", utt)
+    if language == "en":
+        return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
+    if language == "fr":
+        return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
def preprocess_commonvoice(
@ -94,7 +98,7 @@ def preprocess_commonvoice(
for sup in m["supervisions"]:
text = str(sup.text)
orig_text = text
-sup.text = normalize_text(sup.text)
+sup.text = normalize_text(sup.text, language)
text = str(sup.text)
if len(orig_text) != len(text):
logging.info(

View File

@ -36,8 +36,8 @@ num_splits=1000
# - speech
dl_dir=$PWD/download
-release=cv-corpus-13.0-2023-03-09
-lang=en
+release=cv-corpus-12.0-2022-12-07
+lang=fr
. shared/parse_options.sh || exit 1
@ -146,7 +146,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
./local/compute_fbank_commonvoice_splits.py \
--num-workers $nj \
-  --batch-duration 600 \
+  --batch-duration 200 \
--start 0 \
--num-splits $num_splits \
--language $lang
@ -189,7 +189,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/words.txt ]; then
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_dir/words.txt
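For reference, the same word list can be produced in Python; a small sketch mirroring the pipeline above (paths are illustrative):

```python
# Build a sorted, de-duplicated word list from the transcripts,
# equivalent to the sed | sort | sed pipeline above.
words = set()
with open("transcript_words.txt") as f:
    for line in f:
        words.update(line.split())

with open("words.txt", "w") as f:
    for word in sorted(words):
        print(word, file=f)
```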
@ -216,14 +216,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
}' > $lang_dir/words || exit 1;
mv $lang_dir/words $lang_dir/words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
@ -250,3 +250,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
done
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare G"
# We assume you have installed kaldilm. If not, please install
# it with: pip install kaldilm
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/${lang}/lang_bpe_${vocab_size}
mkdir -p $lang_dir/lm
# 3-gram used in building HLG, 4-gram used for LM rescoring
for ngram in 3 4; do
if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order ${ngram} \
-text $lang_dir/transcript_words.txt \
-lm $lang_dir/lm/${ngram}gram.arpa
fi
if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=${ngram} \
$lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
fi
done
done
fi
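Downstream consumers turn the text FSTs generated in this stage into k2 objects; a minimal sketch (the path assumes the layout above and vocab size 500):

```python
# Sketch: load a kaldilm-generated text FST as a k2 FSA and cache it.
import k2
import torch

with open("data/fr/lang_bpe_500/lm/G_4_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Cache in PyTorch format so later runs can skip the text parsing.
torch.save(G.as_dict(), "data/fr/lang_bpe_500/lm/G_4_gram.pt")
```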
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Compile HLG"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/${lang}/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
# Note: if ./local/compile_hlg.py throws OOM,
# please switch to the following command:
#
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Compile LG"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/${lang}/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi

View File

@ -0,0 +1,9 @@
This recipe implements the streaming Zipformer-Transducer model.
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
is to support exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py

View File

@ -0,0 +1,422 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
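# (Used below as the DataLoader's worker_init_fn so that each dataloader
# worker process gets a distinct but reproducible random seed.)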
class CommonVoiceAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--language",
type=str,
default="fr",
help="""Language of Common Voice""",
)
group.add_argument(
"--cv-manifest-dir",
type=Path,
default=Path("data/fr/fbank"),
help="Path to directory with CommonVoice train/dev/test cuts.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
)
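A hypothetical end-to-end use of this data module, mirroring how the recipe's scripts wire it up (the CLI values are illustrative and assume the manifests exist):

```python
# Sketch: construct the data module and its dataloaders.
import argparse

parser = argparse.ArgumentParser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--language", "fr", "--max-duration", "300"])

commonvoice = CommonVoiceAsrDataModule(args)
train_dl = commonvoice.train_dataloaders(commonvoice.train_cuts())
valid_dl = commonvoice.valid_dataloaders(commonvoice.dev_cuts())
```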

View File

@ -0,0 +1,810 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from commonvoice_fr import CommonVoiceAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
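# Pad 30 frames at the tail (and extend feature_lens to match) so the
# streaming encoder has enough right context to emit the final frames.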
feature_lens += 30
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, 30),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
# errs_info = (
# params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
# )
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
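# The encoder consumes 2x-subsampled frames, so its chunk size is half
# of decode_chunk_len (which counts 10 ms feature frames).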
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
commonvoice = CommonVoiceAsrDataModule(args)
test_cuts = commonvoice.test_cuts()
test_dl = commonvoice.test_dataloaders(test_cuts)
test_sets = "test-cv"
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_sets,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py

File diff suppressed because it is too large.

View File

@ -0,0 +1,281 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--epoch 28 \
--avg 15 \
--use-averaged-model True \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--iter 22000 \
--avg 5 \
--use-averaged-model True \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--epoch 28 \
--avg 15 \
--use-averaged-model False \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
--iter 22000 \
--avg 5 \
--use-averaged-model False \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
"""
import argparse
from pathlib import Path
from typing import Dict, List
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model."
"If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
print("Script started")
device = torch.device("cpu")
print(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
print("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
print(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
print(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
print(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = (
params.exp_dir
/ f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
)
torch.save({"model": model.state_dict()}, filename)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
print(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
)
torch.save({"model": model.state_dict()}, filename)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
print("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py

View File

@ -0,0 +1,612 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from commonvoice_fr import CommonVoiceAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A list of DecodeStream, each belonging to an utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# The front-end subsamples features by ((x_len - 7) // 2 + 1) // 2 and the
# max downsampling factor inside the encoder is 8, so each chunk is padded
# to at least 23 frames: after the feature embedding, (23 - 7) // 2 = 8
# frames remain.
tail_length = 23
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
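# Batch the per-stream encoder states into one stacked state; they are
# unstacked again after the streaming forward pass.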
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contains the decode streams that are currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
if audio.max() > 1 or audio.min() < -1:
    logging.warning(
        f"Audio of cut {cut.id} is out of range [-1, 1]; normalizing by its peak."
    )
    audio = audio / max(abs(audio.max()), abs(audio.min()))
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WERs of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
commonvoice = CommonVoiceAsrDataModule(args)
test_cuts = commonvoice.test_cuts()
test_sets = "test-cv"
results_dict = decode_dataset(
cuts=test_cuts,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_sets,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,150 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/commonvoice/ASR
python ./pruned_transducer_stateless7_streaming/test_model.py
"""
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# Test jit script
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
print("Using torch.jit.script")
model = torch.jit.script(model)
def test_model_jit_trace():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
# Test encoder
def _test_encoder():
encoder = model.encoder
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
encoder.decode_chunk_size,
params.decode_chunk_len,
)
T = params.decode_chunk_len + 7
x = torch.zeros(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
states = encoder.get_init_state(device=x.device)
encoder.__class__.forward = encoder.__class__.streaming_forward
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
states1 = encoder.get_init_state(device=x.device)
states2 = traced_encoder.get_init_state(device=x.device)
for i in range(5):
x = torch.randn(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
y2, _, states2 = traced_encoder(x, x_lens, states2)
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
# Test decoder
def _test_decoder():
decoder = model.decoder
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
d1 = decoder(y, need_pad)
d2 = traced_decoder(y, need_pad)
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
# Test joiner
def _test_joiner():
joiner = model.joiner
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
j1 = joiner(encoder_out, decoder_out)
j2 = traced_joiner(encoder_out, decoder_out)
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
_test_encoder()
_test_decoder()
_test_joiner()
def main():
test_model()
test_model_jit_trace()
if __name__ == "__main__":
main()

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py

View File

@ -1,102 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input an FST in k2 format and converts it
to an FST in OpenFST format.
The generated FST is saved into a binary file and its type is
StdVectorFst.
Usage examples:
(1) Convert an acceptor
./convert-k2-to-openfst.py in.pt binary.fst
(2) Convert a transducer
./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
"""
import argparse
import logging
from pathlib import Path
import k2
import kaldifst.utils
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--olabels",
type=str,
default=None,
help="""If not empty, the input FST is assumed to be a transducer
and we use its attribute specified by "olabels" as the output labels.
""",
)
parser.add_argument(
"input_filename",
type=str,
help="Path to the input FST in k2 format",
)
parser.add_argument(
"output_filename",
type=str,
help="Path to the output FST in OpenFst format",
)
return parser.parse_args()
def main():
args = get_args()
logging.info(f"{vars(args)}")
input_filename = args.input_filename
output_filename = args.output_filename
olabels = args.olabels
if Path(output_filename).is_file():
logging.info(f"{output_filename} already exists - skipping")
return
assert Path(input_filename).is_file(), f"{input_filename} does not exist"
logging.info(f"Loading {input_filename}")
k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
if olabels:
assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
p = Path(output_filename).parent
if not p.is_dir():
logging.info(f"Creating {p}")
p.mkdir(parents=True)
logging.info("Converting (May take some time if the input FST is large)")
fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
logging.info(f"Saving to {output_filename}")
fst.write(output_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/shared/convert-k2-to-openfst.py

View File

@ -1,630 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./ngram_entropy_pruning.py \
-threshold 1e-8 \
-lm download/lm/4gram.arpa \
-write-lm download/lm/4gram_pruned_1e8.arpa
This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
in the same way as SRILM.
"""
import argparse
import gzip
import logging
import math
import re
from collections import OrderedDict, defaultdict
from enum import Enum, unique
from io import StringIO
parser = argparse.ArgumentParser(
description="""
Prune an n-gram language model based on the relative entropy
between the original and the pruned model, following Andreas Stolcke's paper.
An n-gram entry is removed if its removal causes the (training set)
perplexity of the model to increase by less than the given relative threshold.
The command takes an arpa file and a pruning threshold as input,
and outputs a pruned arpa file.
"""
)
parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
parser.add_argument(
"-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
)
parser.add_argument(
"-minorder",
type=int,
default=1,
help="The minorder parameter limits pruning to ngrams of that length and above.",
)
parser.add_argument(
"-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
)
parser.add_argument(
"-verbose",
type=int,
default=2,
choices=[0, 1, 2, 3, 4, 5],
help="Verbose level, where 0 is most noisy; 5 is most silent",
)
args = parser.parse_args()
default_encoding = args.encoding
logging.basicConfig(
format="%(asctime)s%(levelname)s%(funcName)s:%(lineno)d%(message)s",
level=args.verbose * 10,
)
class Context(dict):
"""
This class stores data for a context h.
It behaves like a python dict object, except that it has several
additional attributes.
"""
def __init__(self):
super().__init__()
self.log_bo = None
class Arpa:
"""
This is a class that implements the data structure of an ARPA LM.
It (as well as some other classes) is modified based on the library
by Stefan Fischer:
https://github.com/sfischer13/python-arpa
"""
UNK = "<unk>"
SOS = "<s>"
EOS = "</s>"
FLOAT_NDIGITS = 7
base = 10
@staticmethod
def _check_input(my_input):
if not my_input:
raise ValueError
elif isinstance(my_input, tuple):
return my_input
elif isinstance(my_input, list):
return tuple(my_input)
elif isinstance(my_input, str):
return tuple(my_input.strip().split(" "))
else:
raise ValueError
@staticmethod
def _check_word(input_word):
if not isinstance(input_word, str):
raise ValueError
if " " in input_word:
raise ValueError
def _replace_unks(self, words):
return tuple((w if w in self else self._unk) for w in words)
def __init__(self, path=None, encoding=None, unk=None):
self._counts = OrderedDict()
self._ngrams = (
OrderedDict()
) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
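# For example, the trigram ("in", "the", "house") with h = ("in", "the")
# and w = "house" is stored as
#   self._ngrams[2][("in", "the")]["house"] = log10 P(house | in the)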
self._vocabulary = set()
if unk is None:
self._unk = self.UNK
if path is not None:
self.loadf(path, encoding)
def __contains__(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
def contains_word(self, word):
self._check_word(word)
return word in self._vocabulary
def add_count(self, order, count):
self._counts[order] = count
self._ngrams[order - 1] = defaultdict(Context)
def update_counts(self):
for order in range(1, self.order() + 1):
count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
if count > 0:
self._counts[order] = count
def add_entry(self, ngram, p, bo=None, order=None):
# Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
# Note that p and bo here are in fact in the log domain (self.base = 10)
h_context = self._ngrams[len(h)][h]
h_context[w] = p
if bo is not None:
self._ngrams[len(ngram)][ngram].log_bo = bo
for word in ngram:
self._vocabulary.add(word)
def counts(self):
return sorted(self._counts.items())
def order(self):
return max(self._counts.keys(), default=None)
def vocabulary(self, sort=True):
if sort:
return sorted(self._vocabulary)
else:
return self._vocabulary
def _entries(self, order):
return (
self._entry(h, w)
for h, wlist in self._ngrams[order - 1].items()
for w in wlist
)
def _entry(self, h, w):
# return the entry for the ngram (h, w)
ngram = h + (w,)
log_p = self._ngrams[len(h)][h][w]
log_bo = self._log_bo(ngram)
if log_bo is not None:
return (
round(log_p, self.FLOAT_NDIGITS),
ngram,
round(log_bo, self.FLOAT_NDIGITS),
)
else:
return round(log_p, self.FLOAT_NDIGITS), ngram
def _log_bo(self, ngram):
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
return self._ngrams[len(ngram)][ngram].log_bo
else:
return None
def _log_p(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
return self._ngrams[len(h)][h][w]
else:
return None
def log_p_raw(self, ngram):
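# Backoff rule: if (h, w) has no explicit entry, fall back to
#   log P(w | h) = log BOW(h) + log P(w | h[1:]),
# where h[1:] drops the oldest word and a missing BOW(h) counts as 0.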
log_p = self._log_p(ngram)
if log_p is not None:
return log_p
else:
if len(ngram) == 1:
raise KeyError
else:
log_bo = self._log_bo(ngram[:-1])
if log_bo is None:
log_bo = 0
return log_bo + self.log_p_raw(ngram[1:])
def log_joint_prob(self, sequence):
# Compute the joint prob of the sequence based on the chain rule
# Note that sequence should be a tuple of strings
#
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
log_joint_p = 0
seq = sequence
while len(seq) > 0:
log_joint_p += self.log_p_raw(seq)
seq = seq[:-1]
# If we're computing the marginal probability of the unigram
# <s> context we have to look up </s> instead since the former
# has prob = 0.
if len(seq) == 1 and seq[0] == self.SOS:
seq = (self.EOS,)
return log_joint_p
def set_new_context(self, h):
old_context = self._ngrams[len(h)][h]
self._ngrams[len(h)][h] = Context()
return old_context
def log_p(self, ngram):
words = self._check_input(ngram)
if self._unk:
words = self._replace_unks(words)
return self.log_p_raw(words)
def log_s(self, sentence, sos=SOS, eos=EOS):
words = self._check_input(sentence)
if self._unk:
words = self._replace_unks(words)
if sos:
words = (sos,) + words
if eos:
words = words + (eos,)
result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
if sos:
result = result - self.log_p_raw(words[:1])
return result
def p(self, ngram):
return self.base ** self.log_p(ngram)
def s(self, sentence):
return self.base ** self.log_s(sentence)
def write(self, fp):
fp.write("\n\\data\\\n")
for order, count in self.counts():
fp.write("ngram {}={}\n".format(order, count))
fp.write("\n")
for order, _ in self.counts():
fp.write("\\{}-grams:\n".format(order))
for e in self._entries(order):
prob = e[0]
ngram = " ".join(e[1])
if len(e) == 2:
fp.write("{}\t{}\n".format(prob, ngram))
elif len(e) == 3:
backoff = e[2]
fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
else:
raise ValueError
fp.write("\n")
fp.write("\\end\\\n")
class ArpaParser:
"""
This is a class that implements a parser for an ARPA file
"""
@unique
class State(Enum):
DATA = 1
COUNT = 2
HEADER = 3
ENTRY = 4
re_count = re.compile(r"^ngram (\d+)=(\d+)$")
re_header = re.compile(r"^\\(\d+)-grams:$")
re_entry = re.compile(
"^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
"\t"
"(\\S+( \\S+)*)"
"(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
)
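# Matches ARPA entry lines such as "-1.2345\tin the house\t-0.4567":
# group(1) is the log10 probability, group(4) the n-gram itself and
# group(7) the optional backoff weight.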
def _parse(self, fp):
self._result = []
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
for line in fp:
line = line.strip()
if self._state == self.State.DATA:
self._data(line)
elif self._state == self.State.COUNT:
self._count(line)
elif self._state == self.State.HEADER:
self._header(line)
elif self._state == self.State.ENTRY:
self._entry(line)
if self._state != self.State.DATA:
raise Exception(line)
return self._result
def _data(self, line):
if line == "\\data\\":
self._state = self.State.COUNT
self._tmp_model = Arpa()
else:
pass # skip comment line
def _count(self, line):
match = self.re_count.match(line)
if match:
order = match.group(1)
count = match.group(2)
self._tmp_model.add_count(int(order), int(count))
elif not line:
self._state = self.State.HEADER # there are no counts
else:
raise Exception(line)
def _header(self, line):
match = self.re_header.match(line)
if match:
self._state = self.State.ENTRY
self._tmp_order = int(match.group(1))
elif line == "\\end\\":
self._result.append(self._tmp_model)
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
elif not line:
pass # skip empty line
else:
raise Exception(line)
def _entry(self, line):
match = self.re_entry.match(line)
if match:
p = self._float_or_int(match.group(1))
ngram = tuple(match.group(4).split(" "))
bo_match = match.group(7)
bo = self._float_or_int(bo_match) if bo_match else None
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
elif not line:
self._state = self.State.HEADER # last entry
else:
raise Exception(line)
@staticmethod
def _float_or_int(s):
f = float(s)
i = int(f)
if str(i) == s: # don't drop trailing ".0"
return i
else:
return f
def load(self, fp):
"""Deserialize fp (a file-like object) to a Python object."""
return self._parse(fp)
def loadf(self, path, encoding=None):
"""Deserialize path (.arpa, .gz) to a Python object."""
path = str(path)
if path.endswith(".gz"):
with gzip.open(path, mode="rt", encoding=encoding) as f:
return self.load(f)
else:
with open(path, mode="rt", encoding=encoding) as f:
return self.load(f)
def loads(self, s):
"""Deserialize s (a str) to a Python object."""
with StringIO(s) as f:
return self.load(f)
def dump(self, obj, fp):
"""Serialize obj to fp (a file-like object) in ARPA format."""
obj.write(fp)
def dumpf(self, obj, path, encoding=None):
"""Serialize obj to path in ARPA format (.arpa, .gz)."""
path = str(path)
if path.endswith(".gz"):
with gzip.open(path, mode="wt", encoding=encoding) as f:
return self.dump(obj, f)
else:
with open(path, mode="wt", encoding=encoding) as f:
self.dump(obj, f)
def dumps(self, obj):
"""Serialize obj to an ARPA formatted str."""
with StringIO() as f:
self.dump(obj, f)
return f.getvalue()
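# Log-domain addition: log_base(base**log_p + base**prev_log_sum), used
# below to accumulate probability mass over the words seen in a context.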
def add_log_p(prev_log_sum, log_p, base):
return math.log(base**log_p + base**prev_log_sum, base)
def compute_numerator_denominator(lm, h):
log_sum_seen_h = -math.inf
log_sum_seen_h_lower = -math.inf
base = lm.base
for w, log_p in lm._ngrams[len(h)][h].items():
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
ngram = h + (w,)
log_p_lower = lm.log_p_raw(ngram[1:])
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
numerator = 1.0 - base**log_sum_seen_h
denominator = 1.0 - base**log_sum_seen_h_lower
return numerator, denominator
def prune(lm, threshold, minorder):
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
for i in range(
lm.order(), max(minorder - 1, 1), -1
): # i is the order of the ngram (h, w)
logging.info("processing %d-grams ..." % i)
count_pruned_ngrams = 0
h_dict = lm._ngrams[i - 1]
for h in list(h_dict.keys()):
# old backoff weight, BOW(h)
log_bow = lm._log_bo(h)
if log_bow is None:
log_bow = 0
# Compute numerator and denominator of the backoff weight,
# so that we can quickly compute the BOW adjustment due to
# leaving out one prob.
numerator, denominator = compute_numerator_denominator(lm, h)
# assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
# Compute the marginal probability of the context, P(h)
h_log_p = lm.log_joint_prob(h)
all_pruned = True
pruned_w_set = set()
for w, log_p in h_dict[h].items():
ngram = h + (w,)
# lower-order estimate for ngramProb, P(w|h')
backoff_prob = lm.log_p_raw(ngram[1:])
# Compute BOW after removing ngram, BOW'(h)
new_log_bow = math.log(
numerator + lm.base**log_p, lm.base
) - math.log(denominator + lm.base**backoff_prob, lm.base)
# Compute change in entropy due to removal of ngram
delta_prob = backoff_prob + new_log_bow - log_p
delta_entropy = -(lm.base**h_log_p) * (
(lm.base**log_p) * delta_prob
+ numerator * (new_log_bow - log_bow)
)
# compute relative change in model (training set) perplexity
perp_change = lm.base**delta_entropy - 1.0
pruned = threshold > 0 and perp_change < threshold
# Make sure we don't prune ngrams whose backoff nodes are needed
if (
pruned
and len(ngram) in lm._ngrams
and len(lm._ngrams[len(ngram)][ngram]) > 0
):
pruned = False
logging.debug(
"CONTEXT "
+ str(h)
+ " WORD "
+ w
+ " CONTEXTPROB %f " % h_log_p
+ " OLDPROB %f " % log_p
+ " NEWPROB %f " % (backoff_prob + new_log_bow)
+ " DELTA-H %f " % delta_entropy
+ " DELTA-LOGP %f " % delta_prob
+ " PPL-CHANGE %f " % perp_change
+ " PRUNED "
+ str(pruned)
)
if pruned:
pruned_w_set.add(w)
count_pruned_ngrams += 1
else:
all_pruned = False
# If we removed all ngrams for this context we can
# remove the context itself, but only if the present
# context is not a prefix to a longer one.
if all_pruned and len(pruned_w_set) == len(h_dict[h]):
del h_dict[
h
] # this context h is no longer needed, as its ngram prob is stored at its own context h'
elif len(pruned_w_set) > 0:
# The pruning for this context h is actually done here
old_context = lm.set_new_context(h)
for w, p_w in old_context.items():
if w not in pruned_w_set:
lm.add_entry(
h + (w,), p_w
) # the entry hw is stored at the context h
# We need to recompute the back-off weight, but
# this can only be done after completing the pruning
# of the lower-order ngrams.
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
# recompute backoff weights
for i in range(
max(minorder - 1, 1) + 1, lm.order() + 1
): # be careful of this order: from low- to high-order
for h in lm._ngrams[i - 1]:
numerator, denominator = compute_numerator_denominator(lm, h)
new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
lm._ngrams[len(h)][h].log_bo = new_log_bow
# update counts
lm.update_counts()
return
def check_h_is_valid(lm, h):
sum_under_h = sum(
[lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
)
if abs(sum_under_h - 1.0) > 1e-6:
logging.info("warning: %s %f" % (str(h), sum_under_h))
return False
else:
return True
def validate_lm(lm):
# sanity check if the conditional probability sums to one under each context h
for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
logging.info("validating %d-grams ..." % i)
h_dict = lm._ngrams[i - 1]
for h in h_dict.keys():
check_h_is_valid(lm, h)
def compare_two_arpas(path1, path2):
pass
if __name__ == "__main__":
# load an arpa file
logging.info("Loading the arpa file from %s" % args.lm)
parser = ArpaParser()
models = parser.loadf(args.lm, encoding=default_encoding)
lm = models[0] # ARPA files may contain several models.
logging.info("Stats before pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
# prune it, the language model will be modified in-place
logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
prune(lm, args.threshold, args.minorder)
# validate_lm(lm)
# write the arpa language model to a file
logging.info("Stats after pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
logging.info("Saving the pruned arpa file to %s" % args.write_lm)
parser.dumpf(lm, args.write_lm, encoding=default_encoding)
logging.info("Done.")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/shared/ngram_entropy_pruning.py

View File

@ -1,97 +0,0 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
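# A minimal usage sketch (hypothetical script foo.sh):
#   stage=0   # defaults that the options below may override
#   nj=4
#   . shared/parse_options.sh || exit 1
#   echo "stage=$stage nj=$nj"
# Running "./foo.sh --stage 2 --nj 8" prints "stage=2 nj=8".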
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@ -0,0 +1 @@
../../../librispeech/ASR/shared/parse_options.sh