mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
add the pruned_transducer_stateless7_streaming recipe for commonvoice (#1018)
* add the pruned_transducer_stateless7_streaming recipe for commonvoice * fix the symlinks * Update RESULTS.md
This commit is contained in:
parent
231bbcd2b6
commit
1b2e99d374
@ -57,3 +57,28 @@ Pretrained model is available at
|
||||
|
||||
The tensorboard log for training is available at
|
||||
<https://tensorboard.dev/experiment/j4pJQty6RMOkMJtRySREKw/>
|
||||
|
||||
|
||||
### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)
|
||||
|
||||
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
|
||||
|
||||
See #1018 for more details.
|
||||
|
||||
Number of model parameters: 70369391, i.e., 70.37 M
|
||||
|
||||
The best WER for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) is below:
|
||||
|
||||
Results are:
|
||||
|
||||
| decoding method | Test |
|
||||
|----------------------|-------|
|
||||
| greedy search | 9.95 |
|
||||
| modified beam search | 9.57 |
|
||||
| fast beam search | 9.67 |
|
||||
|
||||
Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice.
|
||||
|
||||
Detailed experimental results and Pretrained model are available at
|
||||
<https://huggingface.co/shaojieli/icefall-asr-commonvoice-fr-pruned-transducer-stateless7-streaming-2023-04-02>
|
||||
|
||||
|
1
egs/commonvoice/ASR/local/compile_hlg.py
Symbolic link
1
egs/commonvoice/ASR/local/compile_hlg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_hlg.py
|
1
egs/commonvoice/ASR/local/compile_lg.py
Symbolic link
1
egs/commonvoice/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
@ -56,8 +56,8 @@ def get_args():
|
||||
def compute_fbank_commonvoice_dev_test(language: str):
|
||||
src_dir = Path(f"data/{language}/manifests")
|
||||
output_dir = Path(f"data/{language}/fbank")
|
||||
num_workers = 42
|
||||
batch_duration = 600
|
||||
num_workers = 16
|
||||
batch_duration = 200
|
||||
|
||||
subsets = ("dev", "test")
|
||||
|
||||
|
@ -43,9 +43,13 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def normalize_text(utt: str) -> str:
|
||||
def normalize_text(utt: str, language: str) -> str:
|
||||
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
|
||||
return re.sub(r"[^a-zA-Z\s']", "", utt).upper()
|
||||
utt = re.sub("’", "'", utt)
|
||||
if language == "en":
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
if language == "fr":
|
||||
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
|
||||
|
||||
|
||||
def preprocess_commonvoice(
|
||||
@ -94,7 +98,7 @@ def preprocess_commonvoice(
|
||||
for sup in m["supervisions"]:
|
||||
text = str(sup.text)
|
||||
orig_text = text
|
||||
sup.text = normalize_text(sup.text)
|
||||
sup.text = normalize_text(sup.text, language)
|
||||
text = str(sup.text)
|
||||
if len(orig_text) != len(text):
|
||||
logging.info(
|
||||
|
@ -36,8 +36,8 @@ num_splits=1000
|
||||
# - speech
|
||||
|
||||
dl_dir=$PWD/download
|
||||
release=cv-corpus-13.0-2023-03-09
|
||||
lang=en
|
||||
release=cv-corpus-12.0-2022-12-07
|
||||
lang=fr
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
@ -146,7 +146,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
|
||||
./local/compute_fbank_commonvoice_splits.py \
|
||||
--num-workers $nj \
|
||||
--batch-duration 600 \
|
||||
--batch-duration 200 \
|
||||
--start 0 \
|
||||
--num-splits $num_splits \
|
||||
--language $lang
|
||||
@ -189,7 +189,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
|
||||
sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/words.txt ]; then
|
||||
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_dir/words.txt
|
||||
@ -216,14 +216,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
}' > $lang_dir/words || exit 1;
|
||||
mv $lang_dir/words $lang_dir/words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
@ -250,3 +250,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
mkdir -p $lang_dir/lm
|
||||
#3-gram used in building HLG, 4-gram used for LM rescoring
|
||||
for ngram in 3 4; do
|
||||
if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order ${ngram} \
|
||||
-text $lang_dir/transcript_words.txt \
|
||||
-lm $lang_dir/lm/${ngram}gram.arpa
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=${ngram} \
|
||||
$lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
|
||||
fi
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Compile HLG"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
./local/compile_hlg.py --lang-dir $lang_dir
|
||||
|
||||
# Note If ./local/compile_hlg.py throws OOM,
|
||||
# please switch to the following command
|
||||
#
|
||||
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
# Compile LG for RNN-T fast_beam_search decoding
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 12: Compile LG"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/${lang}/lang_bpe_${vocab_size}
|
||||
./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
@ -0,0 +1,9 @@
|
||||
This recipe implements Streaming Zipformer-Transducer model.
|
||||
|
||||
See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
|
||||
|
||||
[./emformer.py](./emformer.py) and [./train.py](./train.py)
|
||||
are basically the same as
|
||||
[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
|
||||
The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
|
||||
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py
|
@ -0,0 +1,422 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class CommonVoiceAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default="fr",
|
||||
help="""Language of Common Voice""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--cv-manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fr/fbank"),
|
||||
help="Path to directory with CommonVoice train/dev/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
|
||||
)
|
810
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
810
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
@ -0,0 +1,810 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from commonvoice_fr import CommonVoiceAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += 30
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, 30),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
|
||||
# errs_info = (
|
||||
# params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
# )
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CommonVoiceAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
commonvoice = CommonVoiceAsrDataModule(args)
|
||||
|
||||
test_cuts = commonvoice.test_cuts()
|
||||
|
||||
test_dl = commonvoice.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = "test-cv"
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_sets,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/decoder.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
|
1342
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
Executable file
1342
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
|
||||
|
||||
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model True \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
|
||||
|
||||
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
|
||||
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model."
|
||||
"If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
print("Script started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
print("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
print(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
print(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
print(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = (
|
||||
params.exp_dir
|
||||
/ f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
|
||||
)
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/joiner.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/model.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/optim.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling_converter.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
|
612
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
612
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
@ -0,0 +1,612 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 2000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from commonvoice_fr import CommonVoiceAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import stack_states, unstack_states
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A List of DecodeStream, each belonging to a utterance.
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||
# factor in encoders is 8.
|
||||
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||
tail_length = 23
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = stack_states(states)
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
|
||||
states = unstack_states(new_states)
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = states[i]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse Cutset containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
idx = 0
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
initial_states = model.encoder.get_init_state(device=device)
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
if audio.max() > 1 or audio.min() < -1:
|
||||
audio = audio / max(abs(audio.max()), abs(audio.min()))
|
||||
print(audio)
|
||||
print(audio.max())
|
||||
print(audio.min())
|
||||
print(cut)
|
||||
idx += 1
|
||||
print(idx)
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CommonVoiceAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
commonvoice = CommonVoiceAsrDataModule(args)
|
||||
test_cuts = commonvoice.test_cuts()
|
||||
test_sets = "test-cv"
|
||||
|
||||
results_dict = decode_dataset(
|
||||
cuts=test_cuts,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_sets,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
150
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
150
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless7_streaming/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
# Test jit script
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
print("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
||||
|
||||
def test_model_jit_trace():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
# Test encoder
|
||||
def _test_encoder():
|
||||
encoder = model.encoder
|
||||
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
T = params.decode_chunk_len + 7
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder.get_init_state(device=x.device)
|
||||
encoder.__class__.forward = encoder.__class__.streaming_forward
|
||||
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
|
||||
|
||||
states1 = encoder.get_init_state(device=x.device)
|
||||
states2 = traced_encoder.get_init_state(device=x.device)
|
||||
for i in range(5):
|
||||
x = torch.randn(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
|
||||
y2, _, states2 = traced_encoder(x, x_lens, states2)
|
||||
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
|
||||
|
||||
# Test decoder
|
||||
def _test_decoder():
|
||||
decoder = model.decoder
|
||||
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
|
||||
d1 = decoder(y, need_pad)
|
||||
d2 = traced_decoder(y, need_pad)
|
||||
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
|
||||
|
||||
# Test joiner
|
||||
def _test_joiner():
|
||||
joiner = model.joiner
|
||||
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
|
||||
j1 = joiner(encoder_out, decoder_out)
|
||||
j2 = traced_joiner(encoder_out, decoder_out)
|
||||
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
|
||||
|
||||
_test_encoder()
|
||||
_test_decoder()
|
||||
_test_joiner()
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
test_model_jit_trace()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1256
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
1256
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1257
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
1257
egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
|
@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script takes as input an FST in k2 format and convert it
|
||||
to an FST in OpenFST format.
|
||||
|
||||
The generated FST is saved into a binary file and its type is
|
||||
StdVectorFst.
|
||||
|
||||
Usage examples:
|
||||
(1) Convert an acceptor
|
||||
|
||||
./convert-k2-to-openfst.py in.pt binary.fst
|
||||
|
||||
(2) Convert a transducer
|
||||
|
||||
./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import kaldifst.utils
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--olabels",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""If not empty, the input FST is assumed to be a transducer
|
||||
and we use its attribute specified by "olabels" as the output labels.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"input_filename",
|
||||
type=str,
|
||||
help="Path to the input FST in k2 format",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"output_filename",
|
||||
type=str,
|
||||
help="Path to the output FST in OpenFst format",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(f"{vars(args)}")
|
||||
|
||||
input_filename = args.input_filename
|
||||
output_filename = args.output_filename
|
||||
olabels = args.olabels
|
||||
|
||||
if Path(output_filename).is_file():
|
||||
logging.info(f"{output_filename} already exists - skipping")
|
||||
return
|
||||
|
||||
assert Path(input_filename).is_file(), f"{input_filename} does not exist"
|
||||
logging.info(f"Loading {input_filename}")
|
||||
k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
|
||||
if olabels:
|
||||
assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
|
||||
|
||||
p = Path(output_filename).parent
|
||||
if not p.is_dir():
|
||||
logging.info(f"Creating {p}")
|
||||
p.mkdir(parents=True)
|
||||
|
||||
logging.info("Converting (May take some time if the input FST is large)")
|
||||
fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
|
||||
logging.info(f"Saving to {output_filename}")
|
||||
fst.write(output_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
icefall/shared/convert-k2-to-openfst.py
Symbolic link
1
icefall/shared/convert-k2-to-openfst.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/convert-k2-to-openfst.py
|
@ -1,630 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
./ngram_entropy_pruning.py \
|
||||
-threshold 1e-8 \
|
||||
-lm download/lm/4gram.arpa \
|
||||
-write-lm download/lm/4gram_pruned_1e8.arpa
|
||||
|
||||
This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
|
||||
This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
|
||||
in the same way as SRILM.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import gzip
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum, unique
|
||||
from io import StringIO
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
Prune an n-gram language model based on the relative entropy
|
||||
between the original and the pruned model, based on Andreas Stolcke's paper.
|
||||
An n-gram entry is removed, if the removal causes (training set) perplexity
|
||||
of the model to increase by less than threshold relative.
|
||||
|
||||
The command takes an arpa file and a pruning threshold as input,
|
||||
and outputs a pruned arpa file.
|
||||
"""
|
||||
)
|
||||
parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
|
||||
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
|
||||
parser.add_argument(
|
||||
"-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-minorder",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The minorder parameter limits pruning to ngrams of that length and above.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-verbose",
|
||||
type=int,
|
||||
default=2,
|
||||
choices=[0, 1, 2, 3, 4, 5],
|
||||
help="Verbose level, where 0 is most noisy; 5 is most silent",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
default_encoding = args.encoding
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
|
||||
level=args.verbose * 10,
|
||||
)
|
||||
|
||||
|
||||
class Context(dict):
|
||||
"""
|
||||
This class stores data for a context h.
|
||||
It behaves like a python dict object, except that it has several
|
||||
additional attributes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.log_bo = None
|
||||
|
||||
|
||||
class Arpa:
|
||||
"""
|
||||
This is a class that implement the data structure of an APRA LM.
|
||||
It (as well as some other classes) is modified based on the library
|
||||
by Stefan Fischer:
|
||||
https://github.com/sfischer13/python-arpa
|
||||
"""
|
||||
|
||||
UNK = "<unk>"
|
||||
SOS = "<s>"
|
||||
EOS = "</s>"
|
||||
FLOAT_NDIGITS = 7
|
||||
base = 10
|
||||
|
||||
@staticmethod
|
||||
def _check_input(my_input):
|
||||
if not my_input:
|
||||
raise ValueError
|
||||
elif isinstance(my_input, tuple):
|
||||
return my_input
|
||||
elif isinstance(my_input, list):
|
||||
return tuple(my_input)
|
||||
elif isinstance(my_input, str):
|
||||
return tuple(my_input.strip().split(" "))
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
@staticmethod
|
||||
def _check_word(input_word):
|
||||
if not isinstance(input_word, str):
|
||||
raise ValueError
|
||||
if " " in input_word:
|
||||
raise ValueError
|
||||
|
||||
def _replace_unks(self, words):
|
||||
return tuple((w if w in self else self._unk) for w in words)
|
||||
|
||||
def __init__(self, path=None, encoding=None, unk=None):
|
||||
self._counts = OrderedDict()
|
||||
self._ngrams = (
|
||||
OrderedDict()
|
||||
) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
|
||||
self._vocabulary = set()
|
||||
if unk is None:
|
||||
self._unk = self.UNK
|
||||
|
||||
if path is not None:
|
||||
self.loadf(path, encoding)
|
||||
|
||||
def __contains__(self, ngram):
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
|
||||
|
||||
def contains_word(self, word):
|
||||
self._check_word(word)
|
||||
return word in self._vocabulary
|
||||
|
||||
def add_count(self, order, count):
|
||||
self._counts[order] = count
|
||||
self._ngrams[order - 1] = defaultdict(Context)
|
||||
|
||||
def update_counts(self):
|
||||
for order in range(1, self.order() + 1):
|
||||
count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
|
||||
if count > 0:
|
||||
self._counts[order] = count
|
||||
|
||||
def add_entry(self, ngram, p, bo=None, order=None):
|
||||
# Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
|
||||
# Note that p and bo here are in fact in the log domain (self.base = 10)
|
||||
h_context = self._ngrams[len(h)][h]
|
||||
h_context[w] = p
|
||||
if bo is not None:
|
||||
self._ngrams[len(ngram)][ngram].log_bo = bo
|
||||
|
||||
for word in ngram:
|
||||
self._vocabulary.add(word)
|
||||
|
||||
def counts(self):
|
||||
return sorted(self._counts.items())
|
||||
|
||||
def order(self):
|
||||
return max(self._counts.keys(), default=None)
|
||||
|
||||
def vocabulary(self, sort=True):
|
||||
if sort:
|
||||
return sorted(self._vocabulary)
|
||||
else:
|
||||
return self._vocabulary
|
||||
|
||||
def _entries(self, order):
|
||||
return (
|
||||
self._entry(h, w)
|
||||
for h, wlist in self._ngrams[order - 1].items()
|
||||
for w in wlist
|
||||
)
|
||||
|
||||
def _entry(self, h, w):
|
||||
# return the entry for the ngram (h, w)
|
||||
ngram = h + (w,)
|
||||
log_p = self._ngrams[len(h)][h][w]
|
||||
log_bo = self._log_bo(ngram)
|
||||
if log_bo is not None:
|
||||
return (
|
||||
round(log_p, self.FLOAT_NDIGITS),
|
||||
ngram,
|
||||
round(log_bo, self.FLOAT_NDIGITS),
|
||||
)
|
||||
else:
|
||||
return round(log_p, self.FLOAT_NDIGITS), ngram
|
||||
|
||||
def _log_bo(self, ngram):
|
||||
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
|
||||
return self._ngrams[len(ngram)][ngram].log_bo
|
||||
else:
|
||||
return None
|
||||
|
||||
def _log_p(self, ngram):
|
||||
h = ngram[:-1] # h is a tuple
|
||||
w = ngram[-1] # w is a string/word
|
||||
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
|
||||
return self._ngrams[len(h)][h][w]
|
||||
else:
|
||||
return None
|
||||
|
||||
def log_p_raw(self, ngram):
|
||||
log_p = self._log_p(ngram)
|
||||
if log_p is not None:
|
||||
return log_p
|
||||
else:
|
||||
if len(ngram) == 1:
|
||||
raise KeyError
|
||||
else:
|
||||
log_bo = self._log_bo(ngram[:-1])
|
||||
if log_bo is None:
|
||||
log_bo = 0
|
||||
return log_bo + self.log_p_raw(ngram[1:])
|
||||
|
||||
def log_joint_prob(self, sequence):
|
||||
# Compute the joint prob of the sequence based on the chain rule
|
||||
# Note that sequence should be a tuple of strings
|
||||
#
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
|
||||
|
||||
log_joint_p = 0
|
||||
seq = sequence
|
||||
while len(seq) > 0:
|
||||
log_joint_p += self.log_p_raw(seq)
|
||||
seq = seq[:-1]
|
||||
|
||||
# If we're computing the marginal probability of the unigram
|
||||
# <s> context we have to look up </s> instead since the former
|
||||
# has prob = 0.
|
||||
if len(seq) == 1 and seq[0] == self.SOS:
|
||||
seq = (self.EOS,)
|
||||
|
||||
return log_joint_p
|
||||
|
||||
def set_new_context(self, h):
|
||||
old_context = self._ngrams[len(h)][h]
|
||||
self._ngrams[len(h)][h] = Context()
|
||||
return old_context
|
||||
|
||||
def log_p(self, ngram):
|
||||
words = self._check_input(ngram)
|
||||
if self._unk:
|
||||
words = self._replace_unks(words)
|
||||
return self.log_p_raw(words)
|
||||
|
||||
def log_s(self, sentence, sos=SOS, eos=EOS):
|
||||
words = self._check_input(sentence)
|
||||
if self._unk:
|
||||
words = self._replace_unks(words)
|
||||
if sos:
|
||||
words = (sos,) + words
|
||||
if eos:
|
||||
words = words + (eos,)
|
||||
result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
|
||||
if sos:
|
||||
result = result - self.log_p_raw(words[:1])
|
||||
return result
|
||||
|
||||
def p(self, ngram):
|
||||
return self.base ** self.log_p(ngram)
|
||||
|
||||
def s(self, sentence):
|
||||
return self.base ** self.log_s(sentence)
|
||||
|
||||
def write(self, fp):
|
||||
fp.write("\n\\data\\\n")
|
||||
for order, count in self.counts():
|
||||
fp.write("ngram {}={}\n".format(order, count))
|
||||
fp.write("\n")
|
||||
for order, _ in self.counts():
|
||||
fp.write("\\{}-grams:\n".format(order))
|
||||
for e in self._entries(order):
|
||||
prob = e[0]
|
||||
ngram = " ".join(e[1])
|
||||
if len(e) == 2:
|
||||
fp.write("{}\t{}\n".format(prob, ngram))
|
||||
elif len(e) == 3:
|
||||
backoff = e[2]
|
||||
fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
|
||||
else:
|
||||
raise ValueError
|
||||
fp.write("\n")
|
||||
fp.write("\\end\\\n")
|
||||
|
||||
|
||||
class ArpaParser:
|
||||
"""
|
||||
This is a class that implement a parser of an arpa file
|
||||
"""
|
||||
|
||||
@unique
|
||||
class State(Enum):
|
||||
DATA = 1
|
||||
COUNT = 2
|
||||
HEADER = 3
|
||||
ENTRY = 4
|
||||
|
||||
re_count = re.compile(r"^ngram (\d+)=(\d+)$")
|
||||
re_header = re.compile(r"^\\(\d+)-grams:$")
|
||||
re_entry = re.compile(
|
||||
"^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
|
||||
"\t"
|
||||
"(\\S+( \\S+)*)"
|
||||
"(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
|
||||
)
|
||||
|
||||
def _parse(self, fp):
|
||||
self._result = []
|
||||
self._state = self.State.DATA
|
||||
self._tmp_model = None
|
||||
self._tmp_order = None
|
||||
for line in fp:
|
||||
line = line.strip()
|
||||
if self._state == self.State.DATA:
|
||||
self._data(line)
|
||||
elif self._state == self.State.COUNT:
|
||||
self._count(line)
|
||||
elif self._state == self.State.HEADER:
|
||||
self._header(line)
|
||||
elif self._state == self.State.ENTRY:
|
||||
self._entry(line)
|
||||
if self._state != self.State.DATA:
|
||||
raise Exception(line)
|
||||
return self._result
|
||||
|
||||
def _data(self, line):
|
||||
if line == "\\data\\":
|
||||
self._state = self.State.COUNT
|
||||
self._tmp_model = Arpa()
|
||||
else:
|
||||
pass # skip comment line
|
||||
|
||||
def _count(self, line):
|
||||
match = self.re_count.match(line)
|
||||
if match:
|
||||
order = match.group(1)
|
||||
count = match.group(2)
|
||||
self._tmp_model.add_count(int(order), int(count))
|
||||
elif not line:
|
||||
self._state = self.State.HEADER # there are no counts
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
def _header(self, line):
|
||||
match = self.re_header.match(line)
|
||||
if match:
|
||||
self._state = self.State.ENTRY
|
||||
self._tmp_order = int(match.group(1))
|
||||
elif line == "\\end\\":
|
||||
self._result.append(self._tmp_model)
|
||||
self._state = self.State.DATA
|
||||
self._tmp_model = None
|
||||
self._tmp_order = None
|
||||
elif not line:
|
||||
pass # skip empty line
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
def _entry(self, line):
|
||||
match = self.re_entry.match(line)
|
||||
if match:
|
||||
p = self._float_or_int(match.group(1))
|
||||
ngram = tuple(match.group(4).split(" "))
|
||||
bo_match = match.group(7)
|
||||
bo = self._float_or_int(bo_match) if bo_match else None
|
||||
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
|
||||
elif not line:
|
||||
self._state = self.State.HEADER # last entry
|
||||
else:
|
||||
raise Exception(line)
|
||||
|
||||
@staticmethod
|
||||
def _float_or_int(s):
|
||||
f = float(s)
|
||||
i = int(f)
|
||||
if str(i) == s: # don't drop trailing ".0"
|
||||
return i
|
||||
else:
|
||||
return f
|
||||
|
||||
def load(self, fp):
|
||||
"""Deserialize fp (a file-like object) to a Python object."""
|
||||
return self._parse(fp)
|
||||
|
||||
def loadf(self, path, encoding=None):
|
||||
"""Deserialize path (.arpa, .gz) to a Python object."""
|
||||
path = str(path)
|
||||
if path.endswith(".gz"):
|
||||
with gzip.open(path, mode="rt", encoding=encoding) as f:
|
||||
return self.load(f)
|
||||
else:
|
||||
with open(path, mode="rt", encoding=encoding) as f:
|
||||
return self.load(f)
|
||||
|
||||
def loads(self, s):
|
||||
"""Deserialize s (a str) to a Python object."""
|
||||
with StringIO(s) as f:
|
||||
return self.load(f)
|
||||
|
||||
def dump(self, obj, fp):
|
||||
"""Serialize obj to fp (a file-like object) in ARPA format."""
|
||||
obj.write(fp)
|
||||
|
||||
def dumpf(self, obj, path, encoding=None):
|
||||
"""Serialize obj to path in ARPA format (.arpa, .gz)."""
|
||||
path = str(path)
|
||||
if path.endswith(".gz"):
|
||||
with gzip.open(path, mode="wt", encoding=encoding) as f:
|
||||
return self.dump(obj, f)
|
||||
else:
|
||||
with open(path, mode="wt", encoding=encoding) as f:
|
||||
self.dump(obj, f)
|
||||
|
||||
def dumps(self, obj):
|
||||
"""Serialize obj to an ARPA formatted str."""
|
||||
with StringIO() as f:
|
||||
self.dump(obj, f)
|
||||
return f.getvalue()
|
||||
|
||||
|
||||
def add_log_p(prev_log_sum, log_p, base):
|
||||
return math.log(base**log_p + base**prev_log_sum, base)
|
||||
|
||||
|
||||
def compute_numerator_denominator(lm, h):
|
||||
log_sum_seen_h = -math.inf
|
||||
log_sum_seen_h_lower = -math.inf
|
||||
base = lm.base
|
||||
for w, log_p in lm._ngrams[len(h)][h].items():
|
||||
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
|
||||
|
||||
ngram = h + (w,)
|
||||
log_p_lower = lm.log_p_raw(ngram[1:])
|
||||
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
|
||||
|
||||
numerator = 1.0 - base**log_sum_seen_h
|
||||
denominator = 1.0 - base**log_sum_seen_h_lower
|
||||
return numerator, denominator
|
||||
|
||||
|
||||
def prune(lm, threshold, minorder):
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
|
||||
|
||||
for i in range(
|
||||
lm.order(), max(minorder - 1, 1), -1
|
||||
): # i is the order of the ngram (h, w)
|
||||
logging.info("processing %d-grams ..." % i)
|
||||
count_pruned_ngrams = 0
|
||||
|
||||
h_dict = lm._ngrams[i - 1]
|
||||
for h in list(h_dict.keys()):
|
||||
# old backoff weight, BOW(h)
|
||||
log_bow = lm._log_bo(h)
|
||||
if log_bow is None:
|
||||
log_bow = 0
|
||||
|
||||
# Compute numerator and denominator of the backoff weight,
|
||||
# so that we can quickly compute the BOW adjustment due to
|
||||
# leaving out one prob.
|
||||
numerator, denominator = compute_numerator_denominator(lm, h)
|
||||
|
||||
# assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
|
||||
|
||||
# Compute the marginal probability of the context, P(h)
|
||||
h_log_p = lm.log_joint_prob(h)
|
||||
|
||||
all_pruned = True
|
||||
pruned_w_set = set()
|
||||
|
||||
for w, log_p in h_dict[h].items():
|
||||
ngram = h + (w,)
|
||||
|
||||
# lower-order estimate for ngramProb, P(w|h')
|
||||
backoff_prob = lm.log_p_raw(ngram[1:])
|
||||
|
||||
# Compute BOW after removing ngram, BOW'(h)
|
||||
new_log_bow = math.log(
|
||||
numerator + lm.base**log_p, lm.base
|
||||
) - math.log(denominator + lm.base**backoff_prob, lm.base)
|
||||
|
||||
# Compute change in entropy due to removal of ngram
|
||||
delta_prob = backoff_prob + new_log_bow - log_p
|
||||
delta_entropy = -(lm.base**h_log_p) * (
|
||||
(lm.base**log_p) * delta_prob
|
||||
+ numerator * (new_log_bow - log_bow)
|
||||
)
|
||||
|
||||
# compute relative change in model (training set) perplexity
|
||||
perp_change = lm.base**delta_entropy - 1.0
|
||||
|
||||
pruned = threshold > 0 and perp_change < threshold
|
||||
|
||||
# Make sure we don't prune ngrams whose backoff nodes are needed
|
||||
if (
|
||||
pruned
|
||||
and len(ngram) in lm._ngrams
|
||||
and len(lm._ngrams[len(ngram)][ngram]) > 0
|
||||
):
|
||||
pruned = False
|
||||
|
||||
logging.debug(
|
||||
"CONTEXT "
|
||||
+ str(h)
|
||||
+ " WORD "
|
||||
+ w
|
||||
+ " CONTEXTPROB %f " % h_log_p
|
||||
+ " OLDPROB %f " % log_p
|
||||
+ " NEWPROB %f " % (backoff_prob + new_log_bow)
|
||||
+ " DELTA-H %f " % delta_entropy
|
||||
+ " DELTA-LOGP %f " % delta_prob
|
||||
+ " PPL-CHANGE %f " % perp_change
|
||||
+ " PRUNED "
|
||||
+ str(pruned)
|
||||
)
|
||||
|
||||
if pruned:
|
||||
pruned_w_set.add(w)
|
||||
count_pruned_ngrams += 1
|
||||
else:
|
||||
all_pruned = False
|
||||
|
||||
# If we removed all ngrams for this context we can
|
||||
# remove the context itself, but only if the present
|
||||
# context is not a prefix to a longer one.
|
||||
if all_pruned and len(pruned_w_set) == len(h_dict[h]):
|
||||
del h_dict[
|
||||
h
|
||||
] # this context h is no longer needed, as its ngram prob is stored at its own context h'
|
||||
elif len(pruned_w_set) > 0:
|
||||
# The pruning for this context h is actually done here
|
||||
old_context = lm.set_new_context(h)
|
||||
|
||||
for w, p_w in old_context.items():
|
||||
if w not in pruned_w_set:
|
||||
lm.add_entry(
|
||||
h + (w,), p_w
|
||||
) # the entry hw is stored at the context h
|
||||
|
||||
# We need to recompute the back-off weight, but
|
||||
# this can only be done after completing the pruning
|
||||
# of the lower-order ngrams.
|
||||
# Reference:
|
||||
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
|
||||
|
||||
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
|
||||
|
||||
# recompute backoff weights
|
||||
for i in range(
|
||||
max(minorder - 1, 1) + 1, lm.order() + 1
|
||||
): # be careful of this order: from low- to high-order
|
||||
for h in lm._ngrams[i - 1]:
|
||||
numerator, denominator = compute_numerator_denominator(lm, h)
|
||||
new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
|
||||
lm._ngrams[len(h)][h].log_bo = new_log_bow
|
||||
|
||||
# update counts
|
||||
lm.update_counts()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def check_h_is_valid(lm, h):
|
||||
sum_under_h = sum(
|
||||
[lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
|
||||
)
|
||||
if abs(sum_under_h - 1.0) > 1e-6:
|
||||
logging.info("warning: %s %f" % (str(h), sum_under_h))
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def validate_lm(lm):
|
||||
# sanity check if the conditional probability sums to one under each context h
|
||||
for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
|
||||
logging.info("validating %d-grams ..." % i)
|
||||
h_dict = lm._ngrams[i - 1]
|
||||
for h in h_dict.keys():
|
||||
check_h_is_valid(lm, h)
|
||||
|
||||
|
||||
def compare_two_apras(path1, path2):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load an arpa file
|
||||
logging.info("Loading the arpa file from %s" % args.lm)
|
||||
parser = ArpaParser()
|
||||
models = parser.loadf(args.lm, encoding=default_encoding)
|
||||
lm = models[0] # ARPA files may contain several models.
|
||||
logging.info("Stats before pruning:")
|
||||
for i, cnt in lm.counts():
|
||||
logging.info("ngram %d=%d" % (i, cnt))
|
||||
|
||||
# prune it, the language model will be modified in-place
|
||||
logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
|
||||
prune(lm, args.threshold, args.minorder)
|
||||
|
||||
# validate_lm(lm)
|
||||
|
||||
# write the arpa language model to a file
|
||||
logging.info("Stats after pruning:")
|
||||
for i, cnt in lm.counts():
|
||||
logging.info("ngram %d=%d" % (i, cnt))
|
||||
logging.info("Saving the pruned arpa file to %s" % args.write_lm)
|
||||
parser.dumpf(lm, args.write_lm, encoding=default_encoding)
|
||||
logging.info("Done.")
|
1
icefall/shared/ngram_entropy_pruning.py
Symbolic link
1
icefall/shared/ngram_entropy_pruning.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/ngram_entropy_pruning.py
|
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
||||
# Arnab Ghoshal, Karel Vesely
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parse command-line options.
|
||||
# To be sourced by another script (as in ". parse_options.sh").
|
||||
# Option format is: --option-name arg
|
||||
# and shell variable "option_name" gets set to value "arg."
|
||||
# The exception is --help, which takes no arguments, but prints the
|
||||
# $help_message variable (if defined).
|
||||
|
||||
|
||||
###
|
||||
### The --config file options have lower priority to command line
|
||||
### options, so we need to import them first...
|
||||
###
|
||||
|
||||
# Now import all the configs specified by command-line, in left-to-right order
|
||||
for ((argpos=1; argpos<$#; argpos++)); do
|
||||
if [ "${!argpos}" == "--config" ]; then
|
||||
argpos_plus1=$((argpos+1))
|
||||
config=${!argpos_plus1}
|
||||
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
||||
. $config # source the config file.
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
###
|
||||
### Now we process the command line options
|
||||
###
|
||||
while true; do
|
||||
[ -z "${1:-}" ] && break; # break if there are no arguments
|
||||
case "$1" in
|
||||
# If the enclosing script is called with --help option, print the help
|
||||
# message and exit. Scripts should put help messages in $help_message
|
||||
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
||||
else printf "$help_message\n" 1>&2 ; fi;
|
||||
exit 0 ;;
|
||||
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
||||
exit 1 ;;
|
||||
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
||||
# then work out the variable name as $name, which will equal "foo_bar".
|
||||
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
||||
# Next we test whether the variable in question is undefned-- if so it's
|
||||
# an invalid option and we die. Note: $0 evaluates to the name of the
|
||||
# enclosing script.
|
||||
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
||||
# is undefined. We then have to wrap this test inside "eval" because
|
||||
# foo_bar is itself inside a variable ($name).
|
||||
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
||||
|
||||
oldval="`eval echo \\$$name`";
|
||||
# Work out whether we seem to be expecting a Boolean argument.
|
||||
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
||||
was_bool=true;
|
||||
else
|
||||
was_bool=false;
|
||||
fi
|
||||
|
||||
# Set the variable to the right value-- the escaped quotes make it work if
|
||||
# the option had spaces, like --cmd "queue.pl -sync y"
|
||||
eval $name=\"$2\";
|
||||
|
||||
# Check that Boolean-valued arguments are really Boolean.
|
||||
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
||||
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
shift 2;
|
||||
;;
|
||||
*) break;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
# Check for an empty argument to the --cmd option, which can easily occur as a
|
||||
# result of scripting errors.
|
||||
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
||||
|
||||
|
||||
true; # so this script returns exit code 0.
|
1
icefall/shared/parse_options.sh
Symbolic link
1
icefall/shared/parse_options.sh
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/shared/parse_options.sh
|
Loading…
x
Reference in New Issue
Block a user