IWSLT-Ta ASR/ST (#1362)

This is a pull request for Dialectal IWSLT-Tunisian 2022 shared task https://iwslt.org/2022/dialect ASR and ST recipes.
This commit is contained in:
Amir Hussein 2025-09-21 21:58:00 -04:00 committed by GitHub
parent 855536d355
commit 729a5ba3ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
127 changed files with 28855 additions and 1 deletions

View File

@ -0,0 +1,26 @@
# IWSLT_Ta
The IWSLT Tunisian dataset is a 3-way parallel dataset consisting of approximately 160 hours
and 200,000 lines of aligned audio, Tunisian transcripts, and English translations. This dataset
comprises conversational telephone speech recorded at a sampling rate of 8kHz. The train, dev,
and test1 splits of the iwslt2022 shared task correspond to catalog number LDC2022E01. Please
note that access to this data requires an LDC subscription from your institution.To obtain this
dataset, you should download the predefined splits by running the following command:
git clone https://github.com/kevinduh/iwslt22-dialect.git. For more detailed information about
the shared task, please refer to the task paper available at this link:
https://aclanthology.org/2022.iwslt-1.10/.
## Stateless Pruned Transducer Performance Record (after 20 epochs)
| Decoding method | dev WER | test WER | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 47.6 | 51.2 | --epoch 20, --avg 10 |
## Zipformer Performance Record (after 20 epochs)
| Decoding method | dev WER | test WER | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 40.8 | 44.4 | --epoch 20, --avg 10 |
See [RESULTS](RESULTS.md) for details.

View File

@ -0,0 +1,110 @@
# Results
### IWSLT Tunisian training results (Stateless Pruned Transducer)
#### 2023-06-01
| Decoding method | dev WER | test WER | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 47.6 | 51.2 | --epoch 20, --avg 13 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--max-duration 300 \
--num-buckets 50
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/yBijWJSPSGuBqMwTZ509lA/
The decoding command is:
```
for method in modified_beam_search; do
./pruned_transducer_stateless5/decode.py \
--epoch 15 \
--beam-size 20 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 400 \
--decoding-method modified_beam_search \
--max-sym-per-frame 1 \
--num-encoder-layers 12 \
--dim-feedforward 1024 \
--nhead 8 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--use-averaged-model true
done
```
### IWSLT Tunisian training results (Zipformer)
#### 2023-06-01
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/AmirHussein/zipformer-iwslt22-Ta>
| Decoding method | dev WER | test WER | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 40.8 | 44.1 | --epoch 20, --avg 13 |
To reproduce the above result, use the following commands for training:
# Note: the model was trained on V-100 32GB GPU
```
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 4 \
--num-epochs 20 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,384,512,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--max-duration 800 \
--prune-range 10
```
The decoding command is:
```
for method in modified_beam_search; do
./zipformer/decode.py \
--epoch 20 \
--beam-size 20 \
--avg 13 \
--exp-dir ./zipformer/exp\
--max-duration 800 \
--decoding-method $method \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,384,512,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--use-averaged-model true
done
```

View File

View File

@ -0,0 +1 @@
../../ST/local/cer.py

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the MGB2 dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=20,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
return parser.parse_args()
def compute_fbank_gpu(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(parents=True, exist_ok=True)
num_jobs = os.cpu_count()
num_mel_bins = 80
sampling_rate = 16000
sr = 16000
dataset_parts = ("dev", "test1") if args.test else ("train", "test1", "dev")
manifests = read_manifests_if_cached(
prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info("About to split cuts into smaller chunks.")
if sr != None:
logging.info(f"Resampling to {sr}")
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.to_eager()
chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size,)
start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
num_digits = len(str(args.num_splits))
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}",
batch_duration=1000,
num_workers=8,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cs.to_file(cuts_train_idx_path)
else:
logging.info(f"Processing {partition}")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{partition}",
batch_duration=1000,
num_workers=10,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_gpu(args)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1 @@
../../ST/local/cuts_validate.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/display_manifest_statistics.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/filter_cuts.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/generate_unique_lexicon.py

View File

@ -0,0 +1 @@
../../ST/local/prep_lexicon.sh

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1 @@
/exp/ahussein/tmp/icefall/egs/iwslt22_ta/ST/local/prepare_transcripts.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/test_prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_manifest.py

165
egs/iwslt22_ta/ASR/prepare.sh Executable file
View File

@ -0,0 +1,165 @@
#!/usr/bin/env bash
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
set -eou pipefail
nj=20
stage=1
stop_stage=4
# We assume dl_dir (download dir) contains the following
# directories and files.
#
# - $dl_dir/iwslt_ta
#
# You can download the data from
#
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
# Note: iwslt_ta is not available for direct
# download, "Download IWSLT Tunisian from LDC LDC2022E01. This script assumes you prepared the stm files"
#"Check the instructions to prepare the stm files from the raw data here https://github.com/kevinduh/iwslt22-dialect"
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
1000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/iwslt_ta,
# you can create a symlink
#
# ln -sfv /path/to/iwslt_ta $dl_dir/iwslt_ta
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
fbank=data/fbank
manifests=data/manifests
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare iwslt manifest"
# We assume that you have downloaded the iwslt_ta corpus to $dl_dir/iwslt_ta
# Also git clone https://github.com/kevinduh/iwslt22-dialect
if [ ! -d "iwslt22-dialect" ]; then
echo "Splits directory (iwslt22-dialect) does not exist"
echo "Run git clone https://github.com/kevinduh/iwslt22-dialect"
exit 1
fi
manifests=data/manifests
mkdir -p $manifests
lhotse prepare iwslt_ta $dl_dir/iwslt_ta iwslt22-dialect data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p $manifests
lhotse prepare musan $dl_dir/musan $manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank features"
mkdir -p ${fbank}
python local/compute_fbank_gpu.py --num-splits 20
log "Combine features from train splits (may take ~1h)"
if [ ! -f $manifests/cuts_train.jsonl.gz ]; then
pieces=$(find $manifests -name "cuts_train_[0-9]*.jsonl.gz")
lhotse combine $pieces $manifests/cuts_train.jsonl.gz
fi
gunzip -c $manifests/cuts_train.jsonl.gz | shuf | gzip -c > ${fbank}/cuts_train_shuf.jsonl.gz
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p ${fbank}
./local/compute_fbank_musan.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
if [ ! -f download/lm/train/transcript_words.txt ]; then
# export train text file to build grapheme lexicon
log "Creating transcripts in download/lm/train from lhotse cuts"
mkdir -p download/lm/train
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --langdir download/lm/train
fi
mkdir -p $lang_dir
log "Prepare lexicon"
./local/prep_lexicon.sh download/lm/train
python local/prepare_lexicon.py $dl_dir/lm/train/words.txt $dl_dir/lm/train/lexicon.txt
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/train/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p ${lang_dir}
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone_src/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate Tunisian text for BPE training from data/fbank/cuts_train_shuf.jsonl.gz"
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --langdir ${ang_dir}
fi
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
done
done
fi

View File

@ -0,0 +1,396 @@
# Copyright 2023 Amir Hussein
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2Speech2textTranslationDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class IWSLTDialectSTDataModule:
"""
DataModule for k2 ST experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=8,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2Speech2textTranslationDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,960 @@
#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(3) modified beam search
./pruned_transducer_stateless5/decode_asr.py \
--epoch 15 \
--beam-size 20 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp_asr \
--max-duration 400 \
--decoding-method modified_beam_search \
--max-sym-per-frame 1 \
--num-encoder-layers 12 \
--dim-feedforward 1024 \
--nhead 8 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--use-averaged-model true
"""
import argparse
import logging
import math
import pdb
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from lhotse.qa import validate_cut
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import IWSLTDialectSTDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_ta_1000/bpe.model",
help="Path to source data BPE model",
)
parser.add_argument(
"--bpe-tgt-model",
type=str,
default="data/lang_bpe_en_1000/bpe.model",
help="Path to target data BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/ang_bpe_ta_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--lang-tgt-dir",
type=Path,
default="data/lang_bpe_en_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens)
hyps = []
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def remove_short_and_long_utt(c):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 0.5 or c.duration > 30.0:
#logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
#)
return False
if c.supervisions == []:
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsamplin
return True
# def remove_seg(c):
# if c.supervisions[0].id != 'fla_0102_1_0B_00107':
# return True
# else:
# return False
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
logging.info(f"Decoding {batch_idx}-th batch")
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir /
f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir /
f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
IWSLTDialectSTDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
iwslt_ta = IWSLTDialectSTDataModule(args)
test_cuts = iwslt_ta.test_cuts()
dev_cuts = iwslt_ta.dev_cuts()
# lev_test_cuts = lev_test_cuts.filter(remove_short_and_long_utt)
# # lev_test_cuts = lev_test_cuts.filter(remove_seg)
# gulf_test_cuts = gulf_test_cuts.filter(remove_short_and_long_utt)
# egy_test_cuts = egy_test_cuts.filter(remove_short_and_long_utt)
# egy_h5_cuts = egy_sup_cuts.filter(remove_short_and_long_utt)
# egy_sup_cuts = egy_h5_cuts.filter(remove_short_and_long_utt)
test_dl = iwslt_ta.test_dataloaders(test_cuts)
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_all_dl = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_all_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../../ST/zipformer/export.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
"""
from train import get_params, get_transducer_model
def test_model_1():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 24
params.dim_feedforward = 1536 # 384 * 4
params.encoder_dim = 384
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
def test_model_M():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 18
params.dim_feedforward = 1024
params.encoder_dim = 256
params.nhead = 4
params.decoder_dim = 512
params.joiner_dim = 512
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
# test_model_1()
test_model_M()
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

1
egs/iwslt22_ta/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1,396 @@
# Copyright 2023 Amir Hussein
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2Speech2textTranslationDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class IWSLTDialectSTDataModule:
"""
DataModule for k2 ST experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=8,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2Speech2textTranslationDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/beam_search.py

View File

@ -0,0 +1,852 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Johns Hopkins University (Author: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import IWSLTDialectSTDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train_asr import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_ta_1000/bpe.model",
help="Path to source data BPE model",
)
parser.add_argument(
"--bpe-tgt-model",
type=str,
default="data/lang_bpe_en_1000/bpe.model",
help="Path to target data BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/ang_bpe_ta_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--lang-tgt-dir",
type=Path,
default="data/lang_bpe_en_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
use_hat=True,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
IWSLTDialectSTDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
# use predefined parameters that were used during the training
# params.num_encoder_layers = "2,2,2,2,2,2"
# params.feedforward_dim = "256,512,768,1024,768,512"
# params.encoder_dim = "128,256,256,512,256,256"
# params.encoder_unmasked_dim = "64,128,128,256,128,128"
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
iwslt_ta = IWSLTDialectSTDataModule(args)
test_cuts = iwslt_ta.test_cuts()
dev_cuts = iwslt_ta.dev_cuts()
test_dl = iwslt_ta.test_dataloaders(test_cuts)
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_all_dl = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_all_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../ST/zipformer/decoder.py

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1 @@
../../ST/zipformer/export.py

View File

@ -0,0 +1 @@
../../ST/zipformer/generate_averaged_model.py

View File

@ -0,0 +1 @@
../../ST/zipformer/jit_pretrained.py

View File

@ -0,0 +1 @@
../../ST/zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../ST/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../ST/zipformer/pretrained.py

View File

@ -0,0 +1,176 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: ./zipformer/profile.py
"""
import argparse
import logging
import sentencepiece as spm
import torch
from typing import Tuple
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.profiler import get_model_profile
from scaling import BiasNorm
from train import (
get_encoder_embed,
get_encoder_model,
get_joiner_model,
add_model_arguments,
get_params,
)
from zipformer import BypassModule
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_model_arguments(parser)
return parser
def _bias_norm_flops_compute(module, input, output):
assert len(input) == 1, len(input)
# estimate as layer_norm, see icefall/profiler.py
flops = input[0].numel() * 5
module.__flops__ += int(flops)
def _swoosh_module_flops_compute(module, input, output):
# For SwooshL and SwooshR modules
assert len(input) == 1, len(input)
# estimate as swish/silu, see icefall/profiler.py
flops = input[0].numel()
module.__flops__ += int(flops)
def _bypass_module_flops_compute(module, input, output):
# For Bypass module
assert len(input) == 2, len(input)
flops = input[0].numel() * 2
module.__flops__ += int(flops)
MODULE_HOOK_MAPPING = {
BiasNorm: _bias_norm_flops_compute,
BypassModule: _bypass_module_flops_compute,
}
class Model(nn.Module):
"""A Wrapper for encoder, encoder_embed, and encoder_proj"""
def __init__(
self,
encoder: nn.Module,
encoder_embed: nn.Module,
encoder_proj: nn.Module,
) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self, feature: Tensor, feature_lens: Tensor
) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out)
return logits, encoder_out_lens
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
# We only profile the encoder part
model = Model(
encoder=get_encoder_model(params),
encoder_embed=get_encoder_embed(params),
encoder_proj=get_joiner_model(params).encoder_proj,
)
model.eval()
model.to(device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# for 30-second input
B, T, D = 1, 3000, 80
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
flops, params = get_model_profile(
model=model,
args=(feature, feature_lens),
module_hoop_mapping=MODULE_HOOK_MAPPING,
)
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_decode.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -0,0 +1,28 @@
# IWSLT_Ta
The IWSLT Tunisian dataset is a 3-way parallel dataset consisting of approximately 160 hours
and 200,000 lines of aligned audio, Tunisian transcripts, and English translations. This dataset
comprises conversational telephone speech recorded at a sampling rate of 8kHz. The train, dev,
and test1 splits of the iwslt2022 shared task correspond to catalog number LDC2022E01. Please
note that access to this data requires an LDC subscription from your institution.To obtain this
dataset, you should download the predefined splits by running the following command:
git clone https://github.com/kevinduh/iwslt22-dialect.git. For more detailed information about
the shared task, please refer to the task paper available at this link:
https://aclanthology.org/2022.iwslt-1.10/.
## Stateless Pruned Transducer Performance Record (after 20 epochs)
| Decoding method | dev Bleu | test Bleu | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 13, beam(10), pruned range 5 |
## Zipformer Performance Record (after 20 epochs)
| Decoding method | dev Bleu | test Bleu | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 13, beam(10),pruned range 5 |
| modified beam search | 15.5 | 13 | --epoch 20, --avg 13, beam(20),pruned range 5 |
| modified beam search | 17.9 | 14.9 | --epoch 20, --avg 13, beam(20), pruned range 10 |
See [RESULTS](/egs/iwslt_ta/ST/RESULTS.md) for details.

View File

@ -0,0 +1,125 @@
# Results
### IWSLT Tunisian training results (Stateless Pruned Transducer)
#### 2023-06-01
| Decoding method | dev Bleu | test Bleu | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 10, beam(10), pruned range 5 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--max-duration 300 \
--bucketing-sampler 1\
--num-buckets 50
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/YnzQNCVDSxCvP1onrCzg9A/
The decoding command is:
```
for method in modified_beam_search; do
for epoch in 15 20; do
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--beam-size 20 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp_st \
--max-duration 300 \
--decoding-method $method \
--max-sym-per-frame 1 \
--num-encoder-layers 12 \
--dim-feedforward 1024 \
--nhead 8 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--use-averaged-model true
done
done
```
### IWSLT Tunisian training results (Zipformer)
#### 2023-06-01
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/AmirHussein/zipformer-iwslt22-Ta>
| Decoding method | dev Bleu | test Bleu | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 10, beam(10),pruned range 5 |
| modified beam search | 15.5 | 13 | --epoch 20, --avg 10, beam(20),pruned range 5 |
| modified beam search | 18.2 | 14.8 | --epoch 20, --avg 10, beam(20), pruned range 10 |
To reproduce the above result, use the following commands for training:
# Note: the model was trained on V-100 32GB GPU
# ST medium model 42.5M prune-range 10
```
./zipformer/train.py \
--world-size 4 \
--num-epochs 25 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-st-medium \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,384,512,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--max-duration 800 \
--prune-range 10 \
--warm-step 5000 \
--lr-epochs 8 \
--base-lr 0.055 \
--use-hat False
```
The decoding command is:
```
for method in modified_beam_search; do
for epoch in 15 20; do
./zipformer/decode.py \
--epoch $epoch \
--beam-size 20 \
--avg 10 \
--exp-dir ./zipformer/exp-st-medium-prun10 \
--max-duration 800 \
--decoding-method $method \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,384,512,384,256 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--context-size 2 \
--use-averaged-model true \
--use-hat False
done
done
```

View File

View File

@ -0,0 +1,58 @@
#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
This script computes CER for the decodings generated by icefall recipe
"""
import argparse
import jiwer
import os
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dec-file",
type=str,
help="file with decoded text"
)
return parser
def cer_(file):
hyp = []
ref = []
cer_results = 0
ref_lens = 0
with open(file, 'r', encoding='utf-8') as dec:
for line in dec:
id, target = line.split('\t')
id = id[0:-2]
target, txt = target.split("=")
if target == 'ref':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
ref.append(" ".join(word_list))
elif target == 'hyp':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
hyp.append(" ".join(word_list))
for h, r in zip(hyp, ref):
#breakpoint()
cer_results += (jiwer.cer(r, h)*len(r))
ref_lens += len(r)
print(os.path.basename(file))
print(cer_results/ref_lens)
def main():
parse = get_args()
args = parse.parse_args()
cer_(args.dec_file)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the MGB2 dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=20,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
return parser.parse_args()
def compute_fbank_gpu(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = os.cpu_count()
num_mel_bins = 80
sampling_rate = 16000
sr = 16000
dataset_parts = (
"train",
"test1",
"dev",
)
manifests = read_manifests_if_cached(
prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info("About to split cuts into smaller chunks.")
if sr != None:
logging.info(f"Resampling to {sr}")
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.to_eager()
chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size,)
start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
num_digits = len(str(args.num_splits))
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}",
batch_duration=1000,
num_workers=8,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cs.to_file(cuts_train_idx_path)
else:
logging.info(f"Processing {partition}")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{partition}",
batch_duration=1000,
num_workers=10,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_gpu(args)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears first in the lexicon
is kept; others are removed.
If the input transcript is:
hello zoo world hello
world zoo
foo zoo world hellO
and if the lexicon is
<UNK> SPN
hello h e l l o 2
hello h e l l o
world w o r l d
zoo z o o
Then the output is
h e l l o 2 z o o w o r l d h e l l o 2
w o r l d z o o
SPN z o o w o r l d SPN
"""
import argparse
from pathlib import Path
from typing import Dict, List
from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcript",
type=str,
help="The input transcript file."
"We assume that the transcript file consists of "
"lines. Each line consists of space separated words.",
)
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
parser.add_argument(
"--oov", type=str, default="<UNK>", help="The OOV word."
)
return parser.parse_args()
def process_line(
lexicon: Dict[str, List[str]], line: str, oov_token: str
) -> None:
"""
Args:
lexicon:
A dict containing pronunciations. Its keys are words and values
are pronunciations (i.e., tokens).
line:
A line of transcript consisting of space(s) separated words.
oov_token:
The pronunciation of the oov word if a word in `line` is not present
in the lexicon.
Returns:
Return None.
"""
s = ""
words = line.strip().split()
for i, w in enumerate(words):
tokens = lexicon.get(w, oov_token)
s += " ".join(tokens)
s += " "
print(s.strip())
def main():
args = get_args()
assert Path(args.lexicon).is_file()
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
# Only the first pronunciation of a word is kept
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
lexicon = dict(lexicon)
assert args.oov in lexicon
oov_token = lexicon[args.oov]
with open(args.transcript) as f:
for line in f:
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,107 @@
#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
This script helps validating the prepared manifests (recordings, supervisions)
and CutSets
"""
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sup",
type=str,
default="",
help="Supervisions file",
)
parser.add_argument(
"--rec",
type=str,
default="",
help="Recordings file",
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--savecut",
type=str,
default="",
help="name of the cutset to be saved",
)
return parser
def valid_asr(cut):
tol = 2e-3
i=0
total_dur = 0
for c in cut:
if c.supervisions != []:
if c.supervisions[0].end > c.duration + tol:
logging.info(f"Supervision beyond the cut. Cut number: {i}")
total_dur += c.duration
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
elif c.supervisions[0].start < -tol:
logging.info(f"Supervision starts before the cut. Cut number: {i}")
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
else:
continue
else:
logging.info("Empty supervision")
logging.info(f"id: {c.id}")
i += 1
logging.info(f"filtered duration: {total_dur}")
def main():
parser = get_parser()
args = parser.parse_args()
if args.cut != "":
cuts = CutSet.from_file(args.cut)
else:
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
logging.info("Example from supervisions:")
logging.info(supervisions[0])
logging.info("Example from recordings")
print(recordings[0])
logging.info("Fixing manifests")
recordings, supervisions = fix_manifests(recordings, supervisions)
logging.info("Validating manifests")
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
logging.info("Example from cut:")
print(cuts[100])
cuts.describe()
logging.info("Validating manifests for ASR")
valid_asr(cuts)
if args.savecut != "":
cuts.to_file(args.savecut)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/display_manifest_statistics.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/generate_unique_lexicon.py

View File

@ -0,0 +1,21 @@
#!/usr/bin/env bash
# Copyright 2022 QCRI (author: Amir Hussein)
# Apache 2.0
# This script prepares the graphemic lexicon.
dir=data/local/dict
stage=0
lang_dir_src=$1
lang_dir_tgt=$2
cat $lang_dir_src/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_src/uniq_words
cat $lang_dir_tgt/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_tgt/uniq_words
echo "$0: processing lexicon text and creating lexicon... $(date)."
# remove vowels and rare alef wasla
cat $lang_dir_src/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir_src/words.txt
cat $lang_dir_tgt/uniq_words | sed -r '/^\s*$/d' | sort -u > $lang_dir_tgt/words.txt
echo "$0: Lexicon preparation succeeded"

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0
# This script prepares givel a column of words lexicon.
import argparse
def get_args():
parser = argparse.ArgumentParser(
description="""Creates the list of characters and words in lexicon"""
)
parser.add_argument("input", type=str, help="""Input list of words file""")
parser.add_argument("output", type=str, help="""output graphemic lexicon""")
args = parser.parse_args()
return args
def main():
lex = {}
args = get_args()
with open(args.input, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
characters = list(line)
characters = " ".join(
["V" if char == "*" else char for char in characters]
)
lex[line] = characters
with open(args.output, "w", encoding="utf-8") as fp:
for key in sorted(lex):
fp.write(key + " " + lex[key] + "\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,66 @@
# Copyright 2023 Johns Hopkins University (Amir Hussein)
#!/usr/bin/python
"""
This script prepares transcript_words.txt from cutset
"""
from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--src-langdir",
type=str,
default="",
help="name of the source lang-dir",
)
parser.add_argument(
"--tgt-langdir",
type=str,
default=None,
help="name of the target lang-dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
logging.info("Reading the cuts")
cuts = CutSet.from_file(args.cut)
if args.tgt_langdir != None:
logging.info("Target dir is not None")
langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)]
else:
langdirs = [Path(args.src_langdir)]
for langdir in langdirs:
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt:
for c in cuts:
src_txt = c.supervisions[0].text
tgt_txt = c.supervisions[0].custom['translated_text']['eng']
src.write(src_txt + '\n')
tgt.write(tgt_txt + '\n')
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/test_prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1,243 @@
#!/usr/bin/env python
# Copyright 2020 Kanari AI (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import pdb
import numpy as np
import pandas as pd
import re
import string
import argparse
import sys
import os
import pyarabic.number as number
from pyarabic import araby
_unicode = u"\u0622\u0624\u0626\u0628\u062a\u062c\u06af\u062e\u0630\u0632\u0634\u0636\u0638\u063a\u0640\u0642\u0644\u0646\u0648\u064a\u064c\u064e\u0650\u0652\u0670\u067e\u0686\u0621\u0623\u0625\u06a4\u0627\u0629\u062b\u062d\u062f\u0631\u0633\u0635\u0637\u0639\u0641\u0643\u0645\u0647\u0649\u064b\u064d\u064f\u0651\u0671"
_buckwalter = u"|&}btjGx*z$DZg_qlnwyNaio`PJ'><VApvHdrsSTEfkmhYFKu~{"
_backwardMap = {ord(b):a for a,b in zip(_buckwalter, _unicode)}
# def number2text(anumber):
# """
# Convert number to arabic words, for example convert 25 --> خمسة و عشرون
# Example:
# >>> number2text(523)
# خمسمئة وثلاث وعشرون
# @param anumber: input number
# @type anumber: int
# @return: number words
# @rtype: unicode
# """
# # test if the given type is numeric(float or int
# # if ok, convert it to string
# if type(anumber) is int or type(anumber) is float:
# anumber = str(anumber)
# # if the given type is str/unicode, test if it's a valid number
# elif type(anumber) is str or type(anumber) is unicode:
# try:
# a = float(anumber);
# except ValueError:
# return u"صفر"
# # if the given number not a valid return 0
# else:
# return u"صفر"
# arbn = number.ArNumbers()
# arbn.set_feminine(2)
# return arbn.int2str(anumber)
# return total
def fromBuckWalter(s):
return s.translate(_backwardMap)
def read_tsv(data_file):
text_data = list()
infile = open(data_file, encoding='utf-8')
for line in infile:
if not line.strip():
continue
text= line.split('\t')
text_data.append(text)
return text_data
words_to_remove =['#غير_واضح', '#تلعثم', '', '', '#ال', '#بال', '#وال', ' #وب', '###تداخل', 'FRN','RFN',\
' #سي', '#يي', '#هـ' , '#لل', '#بم', '#الش', '#آآ', ' #يت', '#وو', \
'#ومش', '#ول', '#وسي', '#غير_معروف','#العا','#مطا', '#محم' ,' #ماث','#متطو', ' #نشا', '#أأ', '#آآآ', ' #استج']
def normalizeArabic(text):
# text = re.sub("[إأٱآا]", "ا", text)
# text = re.sub("ى", "ي", text)
# text = re.sub("ة", "ه", text)
# text = re.sub("ئ", "ء", text)
# text = re.sub("ؤ", "ء", text)
text = re.sub(r"(ه){2,}", "ههه", text)
text = re.sub(r"(أ){2,}", "ا", text)
text = re.sub(r"(ا){2,}", "ا", text)
text = re.sub(r"(آ){2,}", "ا", text)
text = re.sub(r"(ص){2,}", "ص", text)
text = re.sub(r"(و){2,}", "و", text)
return text
def remove_special_words(text):
# remove special words from the transcription text
for word in words_to_remove:
if word in text:
text = text.replace(word, '')
return text
def remove_hashes(text):
return re.sub(r'#?', '', text)
def remove_english_characters(text):
return re.sub(r'[^\u0600-\u06FF0-9\s]+', '', text)
def remove_diacritics(text):
return re.sub(r'[\u064B-\u0652\u06D4\u0670\u0674\u06D5-\u06ED]+', '', text)
def remove_punctuations(text):
""" This function removes all punctuations except the verbatim """
arabic_punctuations = '''﴿`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ'''
english_punctuations = string.punctuation
all_punctuations = set(arabic_punctuations + english_punctuations)-{'@','%','.'} # remove all non verbatim punctuations
for p in all_punctuations:
if p in text:
text = text.replace(p, '')
text = re.sub('\s+\.','',text) # keep only the "." that is part of a word: marsad@aljazeera.net . => marsad@aljazeera.net
return text
def remove_extra_space(text):
text = text.lower()
text = re.sub('\s+', ' ', text)
text = re.sub('\s+\.\s+', '.', text)
return text
def remove_dot(text):
words = text.split()
res = []
for word in words:
if word.replace('.','').isnumeric(): # remove the dot if it is not part of a number
res.append(word)
else:
word = re.sub('\s+\.','',word)
res.append(word)
return " ".join(res)
def east_to_west_num(text):
eastern_to_western = {"٠":"0","١":"1","٢":"2","٣":"3","٤":"4","٥":"5","٦":"6","٧":"7","٨":"8","٩":"9","٪":"%","_":" ","ڤ":"ف","|":" "}
trans_string = str.maketrans(eastern_to_western)
return text.translate(trans_string)
def remove_repeating_char(text):
return re.sub(r'(.)\1+', r'\1', text)
def remove_single_char_word(text):
"""
Remove single character word from text
Example: I am in a a home for two years => am in home for two years
Args:
text (str): text
Returns:
(str): text with single char removed
"""
words = text.split()
filter_words = [word for word in words if len(word) > 1 or word.isnumeric()]
return " ".join(filter_words)
def seperate_english_characters(text):
text = text.lower()
res = re.findall(r'[a-z]+', text) # search for english
for match in res:
if match not in {'.',' '}:
text = re.sub(match, " "+ match+ " ",text)
text = re.sub('\s+', ' ', text)
return text
def digit2num(text, dig2num=False):
""" This function is used to clean numbers"""
# search for numbers with spaces
# 100 . 000 => 100.000
res = re.search('[0-9]+\s\.\s[0-9]+', text)
if res != None:
t = re.sub(r'\s', '', res[0])
text = re.sub(res[0], t, text)
# seperate numbers glued with words
# 3أشهر => 3 أشهر
# من10الى15 => من 10 الى 15
# pdb.set_trace()
res = re.findall(r'[^\u0600-\u06FF\%\@a-z]+', text) # search for digits
for match in res:
if match not in {'.',' '}:
text = re.sub(match, " "+ match+ " ",text)
text = re.sub('\s+', ' ', text)
# transliterate numbers to digits
# 13 => ثلاثة عشر
if dig2num == True:
words = araby.tokenize(text)
for i in range(len(words)):
digit = re.sub(r'[\u0600-\u06FF]+', '', words[i])
if digit.isnumeric():
sub_word = re.sub(r'[^\u0600-\u06FF]+', '', words[i])
if number.number2text(digit) != 'صفر':
words[i] = sub_word + number.number2text(digit)
else:
pass
return " ".join(words)
else:
return text
def data_cleaning(text):
# text = remove_special_words(text)
text = remove_punctuations(text)
text = remove_single_char_word(text)
text = remove_diacritics(text)
text = seperate_english_characters(text)
text = remove_extra_space(text)
text = remove_dot(text)
text = east_to_west_num(text)
text = digit2num(text, True)
text = normalizeArabic(text)
#text = re.sub(r'#\w{1,2}\b', '', text) # text = re.sub(r'#\w{1,3}\b', '', text)
#text = remove_hashes(text)
#text = normalizeArabic(text)
return text
def main():
input_file = sys.argv[1] # input transcription file with format <id> <text>
to_BW = str(sys.argv[2]) # transform to BW True|False
output_file=sys.argv[3] # output file name
data = read_tsv(input_file)
new_data = []
for i in range(len(data)):
tokens = data[i][0].split()
tokens[1:] = data_cleaning(" ".join(tokens[1:])).split()
#tokens = data_cleaning(" ".join(tokens)).split()
if to_BW == "True":
for i in range(len(tokens[1:])):
tokens[i+1] = fromBuckWalter(tokens[i+1])
new_data.append(" ".join(tokens))
else:
new_data.append(" ".join(tokens))
df = pd.DataFrame(data=new_data)
df.to_csv(output_file, sep = '\n', header=False, index=False)
if __name__ == "__main__":
main()

183
egs/iwslt22_ta/ST/prepare.sh Executable file
View File

@ -0,0 +1,183 @@
#!/usr/bin/env bash
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
set -eou pipefail
nj=20
stage=1
stop_stage=4
# We assume dl_dir (download dir) contains the following
# directories and files.
#
# - $dl_dir/iwslt_ta
#
# You can download the data from
#
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
# Note: iwslt_ta is not available for direct
# download, "Download IWSLT Tunisian from LDC LDC2022E01. This script assumes you prepared the stm files"
#"Check the instructions to prepare the stm files from the raw data here https://github.com/kevinduh/iwslt22-dialect"
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
1000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/iwslt_ta,
# you can create a symlink
#
# ln -sfv /path/to/iwslt_ta $dl_dir/iwslt_ta
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
fbank=data/fbank
manifests=data/manifests
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare iwslt manifest"
# We assume that you have downloaded the iwslt_ta corpus to $dl_dir/iwslt_ta
# Also git clone https://github.com/kevinduh/iwslt22-dialect
if [ ! -d "iwslt22-dialect" ]; then
echo "Splits directory (iwslt22-dialect) does not exist"
echo "Run git clone https://github.com/kevinduh/iwslt22-dialect"
exit 1
fi
manifests=data/manifests
mkdir -p $manifests
lhotse prepare iwslt_ta $dl_dir/iwslt_ta iwslt22-dialect data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p $manifests
lhotse prepare musan $dl_dir/musan $manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank features"
mkdir -p ${fbank}
python local/compute_fbank_gpu.py --num-splits 20
log "Combine features from train splits (may take ~1h)"
if [ ! -f $manifests/cuts_train.jsonl.gz ]; then
pieces=$(find $manifests -name "cuts_train_[0-9]*.jsonl.gz")
lhotse combine $pieces $manifests/cuts_train.jsonl.gz
fi
gunzip -c $manifests/cuts_train.jsonl.gz | shuf | gzip -c > ${fbank}/cuts_train_shuf.jsonl.gz
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p ${fbank}
./local/compute_fbank_musan.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir_src=data/lang_phone_src
lang_dir_tgt=data/lang_phone_tgt
if [ ! -f download/lm/train_src/transcript_words.txt ] || [ ! -f download/lm/train_tgt/transcript_words.txt ]; then
# export train text file to build grapheme lexicon
log "Creating transcripts in download/lm/train from lhotse cuts"
mkdir -p download/lm/train_src
mkdir -p download/lm/train_tgt
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --src-langdir download/lm/train_src --tgt-langdir download/lm/train_tgt
fi
mkdir -p $lang_dir_src
mkdir -p $lang_dir_tgt
log "Prepare lexicon"
./local/prep_lexicon.sh download/lm/train_src download/lm/train_tgt
python local/prepare_lexicon.py $dl_dir/lm/train_src/words.txt $dl_dir/lm/train_src/lexicon.txt
python local/prepare_lexicon.py $dl_dir/lm/train_tgt/words.txt $dl_dir/lm/train_tgt/lexicon.txt
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/train_src/lexicon.txt |
sort | uniq > $lang_dir_src/lexicon.txt
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/train_tgt/lexicon.txt |
sort | uniq > $lang_dir_tgt/lexicon.txt
if [ ! -f $lang_dir_src/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir_src
fi
if [ ! -f $lang_dir_tgt/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir_tgt
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
srctag=ta
tgttag=en
for vocab_size in ${vocab_sizes[@]}; do
src_lang_dir=data/lang_bpe_${srctag}_${vocab_size}
tgt_lang_dir=data/lang_bpe_${tgttag}_${vocab_size}
mkdir -p ${src_lang_dir}
mkdir -p ${tgt_lang_dir}
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone_src/words.txt $src_lang_dir
cp data/lang_phone_tgt/words.txt $tgt_lang_dir
if [ ! -f $src_lang_dir/transcript_words.txt ] || [ ! -f $tgt_lang_dir/transcript_words.txt ]; then
log "Generate data for ${srctag} and ${tgttag} BPE training from data/fbank/cuts_train_shuf.jsonl.gz"
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --src-langdir ${src_lang_dir} --tgt-langdir ${tgt_lang_dir}
fi
for lang_dir in $src_lang_dir $tgt_lang_dir; do
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
done
done
fi

View File

@ -0,0 +1,396 @@
# Copyright 2023 Amir Hussein
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2Speech2textTranslationDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class IWSLTDialectSTDataModule:
"""
DataModule for k2 ST experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2Speech2textTranslationDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=8,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2Speech2textTranslationDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,948 @@
#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(3) modified beam search
./pruned_transducer_stateless5/decode_st.py \
--epoch 12 \
--beam-size 20 \
--avg 3 \
--exp-dir ./pruned_transducer_stateless5/exp_st \
--max-duration 400 \
--decoding-method modified_beam_search \
--max-sym-per-frame 1 \
--num-encoder-layers 12 \
--dim-feedforward 1024 \
--nhead 8 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--use-averaged-model true
"""
import argparse
import logging
import math
import pdb
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from lhotse.qa import validate_cut
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import IWSLTDialectSTDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
)
from train_st import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
store_translations,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_ta_1000/bpe.model",
help="Path to source data BPE model",
)
parser.add_argument(
"--bpe-tgt-model",
type=str,
default="data/lang_bpe_en_1000/bpe.model",
help="Path to target data BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_ta_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--lang-tgt-dir",
type=Path,
default="data/lang_bpe_en_1000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens)
hyps = []
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def remove_short_and_long_utt(c):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 0.5 or c.duration > 30.0:
#logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
#)
return False
if c.supervisions == []:
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsamplin
return True
# def remove_seg(c):
# if c.supervisions[0].id != 'fla_0102_1_0B_00107':
# return True
# else:
# return False
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
texts_tgt = batch["supervisions"]["tgt_text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
#breakpoint()
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text, ref_text_tgt in zip(cut_ids, hyps, texts, texts_tgt):
ref_words = ref_text.split()
ref_words_tgt = ref_text_tgt.split()
this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words))
results[name].extend(this_batch)
#breakpoint()
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_translations(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_translations(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@torch.no_grad()
def main():
parser = get_parser()
IWSLTDialectSTDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_tgt_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_tgt_dir)
word_table = lexicon.word_table
lg_filename = params.lang_tgt_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
iwslt_ta = IWSLTDialectSTDataModule(args)
test_cuts = iwslt_ta.test_cuts()
dev_cuts = iwslt_ta.dev_cuts()
# lev_test_cuts = lev_test_cuts.filter(remove_short_and_long_utt)
# # lev_test_cuts = lev_test_cuts.filter(remove_seg)
# gulf_test_cuts = gulf_test_cuts.filter(remove_short_and_long_utt)
# egy_test_cuts = egy_test_cuts.filter(remove_short_and_long_utt)
# egy_h5_cuts = egy_sup_cuts.filter(remove_short_and_long_utt)
# egy_sup_cuts = egy_h5_cuts.filter(remove_short_and_long_utt)
test_dl = iwslt_ta.test_dataloaders(test_cuts)
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_all_dl = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_all_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/export.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
"""
from train import get_params, get_transducer_model
def test_model_1():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 24
params.dim_feedforward = 1536 # 384 * 4
params.encoder_dim = 384
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
def test_model_M():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 18
params.dim_feedforward = 1024
params.encoder_dim = 256
params.nhead = 4
params.decoder_dim = 512
params.joiner_dim = 512
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
# test_model_1()
test_model_M()
if __name__ == "__main__":
main()

Some files were not shown because too many files have changed in this diff Show More