mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
IWSLT-Ta ASR/ST (#1362)
This is a pull request for Dialectal IWSLT-Tunisian 2022 shared task https://iwslt.org/2022/dialect ASR and ST recipes.
This commit is contained in:
parent
855536d355
commit
729a5ba3ec
26
egs/iwslt22_ta/ASR/README.md
Normal file
26
egs/iwslt22_ta/ASR/README.md
Normal file
@ -0,0 +1,26 @@
|
||||
# IWSLT_Ta
|
||||
|
||||
The IWSLT Tunisian dataset is a 3-way parallel dataset consisting of approximately 160 hours
|
||||
and 200,000 lines of aligned audio, Tunisian transcripts, and English translations. This dataset
|
||||
comprises conversational telephone speech recorded at a sampling rate of 8kHz. The train, dev,
|
||||
and test1 splits of the iwslt2022 shared task correspond to catalog number LDC2022E01. Please
|
||||
note that access to this data requires an LDC subscription from your institution.To obtain this
|
||||
dataset, you should download the predefined splits by running the following command:
|
||||
git clone https://github.com/kevinduh/iwslt22-dialect.git. For more detailed information about
|
||||
the shared task, please refer to the task paper available at this link:
|
||||
https://aclanthology.org/2022.iwslt-1.10/.
|
||||
|
||||
## Stateless Pruned Transducer Performance Record (after 20 epochs)
|
||||
|
||||
| Decoding method | dev WER | test WER | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 47.6 | 51.2 | --epoch 20, --avg 10 |
|
||||
|
||||
## Zipformer Performance Record (after 20 epochs)
|
||||
|
||||
| Decoding method | dev WER | test WER | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 40.8 | 44.4 | --epoch 20, --avg 10 |
|
||||
|
||||
|
||||
See [RESULTS](RESULTS.md) for details.
|
||||
110
egs/iwslt22_ta/ASR/RESULTS.md
Normal file
110
egs/iwslt22_ta/ASR/RESULTS.md
Normal file
@ -0,0 +1,110 @@
|
||||
# Results
|
||||
|
||||
|
||||
|
||||
### IWSLT Tunisian training results (Stateless Pruned Transducer)
|
||||
|
||||
#### 2023-06-01
|
||||
|
||||
|
||||
| Decoding method | dev WER | test WER | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 47.6 | 51.2 | --epoch 20, --avg 13 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
|
||||
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--max-duration 300 \
|
||||
--num-buckets 50
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/yBijWJSPSGuBqMwTZ509lA/
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
for method in modified_beam_search; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 15 \
|
||||
--beam-size 20 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 400 \
|
||||
--decoding-method modified_beam_search \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 12 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 256 \
|
||||
--joiner-dim 256 \
|
||||
--use-averaged-model true
|
||||
done
|
||||
```
|
||||
|
||||
### IWSLT Tunisian training results (Zipformer)
|
||||
|
||||
#### 2023-06-01
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||
<https://huggingface.co/AmirHussein/zipformer-iwslt22-Ta>
|
||||
|
||||
|
||||
|
||||
| Decoding method | dev WER | test WER | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 40.8 | 44.1 | --epoch 20, --avg 13 |
|
||||
|
||||
To reproduce the above result, use the following commands for training:
|
||||
|
||||
# Note: the model was trained on V-100 32GB GPU
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
./zipformer/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp \
|
||||
--causal 0 \
|
||||
--num-encoder-layers 2,2,2,2,2,2 \
|
||||
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||
--encoder-dim 192,256,384,512,384,256 \
|
||||
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||
--max-duration 800 \
|
||||
--prune-range 10
|
||||
|
||||
```
|
||||
|
||||
The decoding command is:
|
||||
|
||||
```
|
||||
for method in modified_beam_search; do
|
||||
./zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--beam-size 20 \
|
||||
--avg 13 \
|
||||
--exp-dir ./zipformer/exp\
|
||||
--max-duration 800 \
|
||||
--decoding-method $method \
|
||||
--num-encoder-layers 2,2,2,2,2,2 \
|
||||
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||
--encoder-dim 192,256,384,512,384,256 \
|
||||
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||
--use-averaged-model true
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
0
egs/iwslt22_ta/ASR/local/__init__.py
Executable file
0
egs/iwslt22_ta/ASR/local/__init__.py
Executable file
1
egs/iwslt22_ta/ASR/local/cer.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/cer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/local/cer.py
|
||||
168
egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py
Executable file
168
egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py
Executable file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
# Johns Hopkins University (authors: Amir Hussein)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the MGB2 dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of splits for the train set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Start index of the train set split.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop index of the train set split.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="If set, only compute features for the dev and val set.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_gpu(args):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
num_jobs = os.cpu_count()
|
||||
num_mel_bins = 80
|
||||
sampling_rate = 16000
|
||||
sr = 16000
|
||||
|
||||
dataset_parts = ("dev", "test1") if args.test else ("train", "test1", "dev")
|
||||
manifests = read_manifests_if_cached(
|
||||
prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
if sr != None:
|
||||
logging.info(f"Resampling to {sr}")
|
||||
cut_set = cut_set.resample(sr)
|
||||
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False,
|
||||
keep_all_channels=False)
|
||||
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.to_eager()
|
||||
chunk_size = len(cut_set) // args.num_splits
|
||||
cut_sets = cut_set.split_lazy(
|
||||
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
|
||||
chunk_size=chunk_size,)
|
||||
start = args.start
|
||||
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
|
||||
num_digits = len(str(args.num_splits))
|
||||
|
||||
for i in range(start, stop):
|
||||
idx = f"{i + 1}".zfill(num_digits)
|
||||
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
|
||||
logging.info(f"Processing train split {i}")
|
||||
cs = cut_sets[i].compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"feats_train_{idx}",
|
||||
batch_duration=1000,
|
||||
num_workers=8,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cs.to_file(cuts_train_idx_path)
|
||||
else:
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"feats_{partition}",
|
||||
batch_duration=1000,
|
||||
num_workers=10,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
|
||||
compute_fbank_gpu(args)
|
||||
1
egs/iwslt22_ta/ASR/local/compute_fbank_musan.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compute_fbank_musan.py
|
||||
1
egs/iwslt22_ta/ASR/local/cuts_validate.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/cuts_validate.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/local/cuts_validate.py
|
||||
1
egs/iwslt22_ta/ASR/local/display_manifest_statistics.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/display_manifest_statistics.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/display_manifest_statistics.py
|
||||
1
egs/iwslt22_ta/ASR/local/filter_cuts.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/filter_cuts.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/filter_cuts.py
|
||||
1
egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/generate_unique_lexicon.py
|
||||
1
egs/iwslt22_ta/ASR/local/prep_lexicon.sh
Symbolic link
1
egs/iwslt22_ta/ASR/local/prep_lexicon.sh
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/local/prep_lexicon.sh
|
||||
1
egs/iwslt22_ta/ASR/local/prepare_lang.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang.py
|
||||
1
egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang_bpe.py
|
||||
1
egs/iwslt22_ta/ASR/local/prepare_transcripts.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/prepare_transcripts.py
Symbolic link
@ -0,0 +1 @@
|
||||
/exp/ahussein/tmp/icefall/egs/iwslt22_ta/ST/local/prepare_transcripts.py
|
||||
1
egs/iwslt22_ta/ASR/local/test_prepare_lang.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/test_prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/test_prepare_lang.py
|
||||
1
egs/iwslt22_ta/ASR/local/train_bpe_model.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/train_bpe_model.py
|
||||
1
egs/iwslt22_ta/ASR/local/validate_manifest.py
Symbolic link
1
egs/iwslt22_ta/ASR/local/validate_manifest.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/validate_manifest.py
|
||||
165
egs/iwslt22_ta/ASR/prepare.sh
Executable file
165
egs/iwslt22_ta/ASR/prepare.sh
Executable file
@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=20
|
||||
stage=1
|
||||
stop_stage=4
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files.
|
||||
#
|
||||
# - $dl_dir/iwslt_ta
|
||||
#
|
||||
# You can download the data from
|
||||
#
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
#
|
||||
# Note: iwslt_ta is not available for direct
|
||||
# download, "Download IWSLT Tunisian from LDC LDC2022E01. This script assumes you prepared the stm files"
|
||||
#"Check the instructions to prepare the stm files from the raw data here https://github.com/kevinduh/iwslt22-dialect"
|
||||
|
||||
dl_dir=$PWD/download
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bpe_xxx,
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
1000
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/iwslt_ta,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/iwslt_ta $dl_dir/iwslt_ta
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
fbank=data/fbank
|
||||
manifests=data/manifests
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare iwslt manifest"
|
||||
# We assume that you have downloaded the iwslt_ta corpus to $dl_dir/iwslt_ta
|
||||
# Also git clone https://github.com/kevinduh/iwslt22-dialect
|
||||
if [ ! -d "iwslt22-dialect" ]; then
|
||||
echo "Splits directory (iwslt22-dialect) does not exist"
|
||||
echo "Run git clone https://github.com/kevinduh/iwslt22-dialect"
|
||||
exit 1
|
||||
fi
|
||||
manifests=data/manifests
|
||||
mkdir -p $manifests
|
||||
|
||||
lhotse prepare iwslt_ta $dl_dir/iwslt_ta iwslt22-dialect data/manifests
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
mkdir -p $manifests
|
||||
lhotse prepare musan $dl_dir/musan $manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute fbank features"
|
||||
mkdir -p ${fbank}
|
||||
python local/compute_fbank_gpu.py --num-splits 20
|
||||
|
||||
log "Combine features from train splits (may take ~1h)"
|
||||
if [ ! -f $manifests/cuts_train.jsonl.gz ]; then
|
||||
pieces=$(find $manifests -name "cuts_train_[0-9]*.jsonl.gz")
|
||||
lhotse combine $pieces $manifests/cuts_train.jsonl.gz
|
||||
fi
|
||||
gunzip -c $manifests/cuts_train.jsonl.gz | shuf | gzip -c > ${fbank}/cuts_train_shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
mkdir -p ${fbank}
|
||||
./local/compute_fbank_musan.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare phone based lang"
|
||||
lang_dir=data/lang_phone
|
||||
|
||||
if [ ! -f download/lm/train/transcript_words.txt ]; then
|
||||
# export train text file to build grapheme lexicon
|
||||
log "Creating transcripts in download/lm/train from lhotse cuts"
|
||||
mkdir -p download/lm/train
|
||||
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --langdir download/lm/train
|
||||
fi
|
||||
mkdir -p $lang_dir
|
||||
|
||||
log "Prepare lexicon"
|
||||
|
||||
./local/prep_lexicon.sh download/lm/train
|
||||
python local/prepare_lexicon.py $dl_dir/lm/train/words.txt $dl_dir/lm/train/lexicon.txt
|
||||
|
||||
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
|
||||
cat - $dl_dir/lm/train/lexicon.txt |
|
||||
sort | uniq > $lang_dir/lexicon.txt
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang.py --lang-dir $lang_dir
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare BPE based lang"
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
mkdir -p ${lang_dir}
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp data/lang_phone_src/words.txt $lang_dir
|
||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
||||
log "Generate Tunisian text for BPE training from data/fbank/cuts_train_shuf.jsonl.gz"
|
||||
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --langdir ${ang_dir}
|
||||
fi
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
fi
|
||||
done
|
||||
done
|
||||
fi
|
||||
@ -0,0 +1,396 @@
|
||||
# Copyright 2023 Amir Hussein
|
||||
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2Speech2textTranslationDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class IWSLTDialectSTDataModule:
|
||||
|
||||
"""
|
||||
DataModule for k2 ST experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=8,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2Speech2textTranslationDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
|
||||
1809
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/conformer.py
Normal file
1809
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
960
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py
Executable file
960
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py
Executable file
@ -0,0 +1,960 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless5/decode_asr.py \
|
||||
--epoch 15 \
|
||||
--beam-size 20 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_asr \
|
||||
--max-duration 400 \
|
||||
--decoding-method modified_beam_search \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 12 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 256 \
|
||||
--joiner-dim 256 \
|
||||
--use-averaged-model true
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import pdb
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from lhotse.qa import validate_cut
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import IWSLTDialectSTDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_rnnlm_shallow_fusion,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_ta_1000/bpe.model",
|
||||
help="Path to source data BPE model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpe-tgt-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_en_1000/bpe.model",
|
||||
help="Path to target data BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/ang_bpe_ta_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-tgt-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_en_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
def remove_short_and_long_utt(c):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 0.5 or c.duration > 30.0:
|
||||
#logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
#)
|
||||
return False
|
||||
if c.supervisions == []:
|
||||
return False
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
# where T is the number of feature frames after subsampling
|
||||
# and S is the number of tokens in the utterance
|
||||
|
||||
# In ./conformer.py, the conv module uses the following expression
|
||||
# for subsamplin
|
||||
|
||||
return True
|
||||
|
||||
# def remove_seg(c):
|
||||
# if c.supervisions[0].id != 'fla_0102_1_0B_00107':
|
||||
# return True
|
||||
# else:
|
||||
# return False
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
logging.info(f"Decoding {batch_idx}-th batch")
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir /
|
||||
f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir /
|
||||
f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
IWSLTDialectSTDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_rnnlm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
rnn_lm_model = None
|
||||
rnn_lm_scale = params.rnn_lm_scale
|
||||
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
assert params.rnn_lm_avg == 1
|
||||
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
iwslt_ta = IWSLTDialectSTDataModule(args)
|
||||
|
||||
test_cuts = iwslt_ta.test_cuts()
|
||||
dev_cuts = iwslt_ta.dev_cuts()
|
||||
|
||||
# lev_test_cuts = lev_test_cuts.filter(remove_short_and_long_utt)
|
||||
# # lev_test_cuts = lev_test_cuts.filter(remove_seg)
|
||||
# gulf_test_cuts = gulf_test_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_test_cuts = egy_test_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_h5_cuts = egy_sup_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_sup_cuts = egy_h5_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
test_dl = iwslt_ta.test_dataloaders(test_cuts)
|
||||
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_all_dl = [test_dl, dev_dl]
|
||||
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_all_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
rnnlm=rnn_lm_model,
|
||||
rnnlm_scale=rnn_lm_scale,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/export.py
|
||||
1306
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py
Executable file
1306
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
1323
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py
Executable file
1323
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/model.py
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
|
||||
352
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py
Executable file
352
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py
Executable file
@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless5/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py
Symbolic link
1
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
|
||||
65
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/test_model.py
Executable file
65
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/test_model.py
Executable file
@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 24
|
||||
params.dim_feedforward = 1536 # 384 * 4
|
||||
params.encoder_dim = 384
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||
def test_model_M():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 18
|
||||
params.dim_feedforward = 1024
|
||||
params.encoder_dim = 256
|
||||
params.nhead = 4
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
def main():
|
||||
# test_model_1()
|
||||
test_model_M()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1301
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py
Executable file
1301
egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/iwslt22_ta/ASR/shared
Symbolic link
1
egs/iwslt22_ta/ASR/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
||||
396
egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py
Normal file
396
egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,396 @@
|
||||
# Copyright 2023 Amir Hussein
|
||||
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2Speech2textTranslationDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class IWSLTDialectSTDataModule:
|
||||
|
||||
"""
|
||||
DataModule for k2 ST experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=8,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2Speech2textTranslationDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")
|
||||
1
egs/iwslt22_ta/ASR/zipformer/beam_search.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/beam_search.py
|
||||
852
egs/iwslt22_ta/ASR/zipformer/decode.py
Executable file
852
egs/iwslt22_ta/ASR/zipformer/decode.py
Executable file
@ -0,0 +1,852 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Johns Hopkins University (Author: Amir Hussein)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./zipformer/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import IWSLTDialectSTDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train_asr import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
make_pad_mask,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_ta_1000/bpe.model",
|
||||
help="Path to source data BPE model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpe-tgt-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_en_1000/bpe.model",
|
||||
help="Path to target data BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/ang_bpe_ta_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-tgt-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_en_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.causal:
|
||||
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||
pad_len = 30
|
||||
feature_lens += pad_len
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, pad_len),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
x, x_lens = model.encoder_embed(feature, feature_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x, x_lens, src_key_padding_mask
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
use_hat=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
IWSLTDialectSTDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
# use predefined parameters that were used during the training
|
||||
# params.num_encoder_layers = "2,2,2,2,2,2"
|
||||
# params.feedforward_dim = "256,512,768,1024,768,512"
|
||||
# params.encoder_dim = "128,256,256,512,256,256"
|
||||
# params.encoder_unmasked_dim = "64,128,128,256,128,128"
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
iwslt_ta = IWSLTDialectSTDataModule(args)
|
||||
|
||||
test_cuts = iwslt_ta.test_cuts()
|
||||
dev_cuts = iwslt_ta.dev_cuts()
|
||||
test_dl = iwslt_ta.test_dataloaders(test_cuts)
|
||||
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_all_dl = [test_dl, dev_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_all_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/iwslt22_ta/ASR/zipformer/decoder.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/decoder.py
|
||||
43
egs/iwslt22_ta/ASR/zipformer/encoder_interface.py
Normal file
43
egs/iwslt22_ta/ASR/zipformer/encoder_interface.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
||||
1
egs/iwslt22_ta/ASR/zipformer/export.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/export.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/generate_averaged_model.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/jit_pretrained.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/jit_pretrained_streaming.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/joiner.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/joiner.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/model.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/model.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/optim.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/optim.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/pretrained.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ST/zipformer/pretrained.py
|
||||
176
egs/iwslt22_ta/ASR/zipformer/profile.py
Executable file
176
egs/iwslt22_ta/ASR/zipformer/profile.py
Executable file
@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage: ./zipformer/profile.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BiasNorm
|
||||
from train import (
|
||||
get_encoder_embed,
|
||||
get_encoder_model,
|
||||
get_joiner_model,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
)
|
||||
from zipformer import BypassModule
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _bias_norm_flops_compute(module, input, output):
|
||||
assert len(input) == 1, len(input)
|
||||
# estimate as layer_norm, see icefall/profiler.py
|
||||
flops = input[0].numel() * 5
|
||||
module.__flops__ += int(flops)
|
||||
|
||||
|
||||
def _swoosh_module_flops_compute(module, input, output):
|
||||
# For SwooshL and SwooshR modules
|
||||
assert len(input) == 1, len(input)
|
||||
# estimate as swish/silu, see icefall/profiler.py
|
||||
flops = input[0].numel()
|
||||
module.__flops__ += int(flops)
|
||||
|
||||
|
||||
def _bypass_module_flops_compute(module, input, output):
|
||||
# For Bypass module
|
||||
assert len(input) == 2, len(input)
|
||||
flops = input[0].numel() * 2
|
||||
module.__flops__ += int(flops)
|
||||
|
||||
|
||||
MODULE_HOOK_MAPPING = {
|
||||
BiasNorm: _bias_norm_flops_compute,
|
||||
BypassModule: _bypass_module_flops_compute,
|
||||
}
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""A Wrapper for encoder, encoder_embed, and encoder_proj"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: nn.Module,
|
||||
encoder_embed: nn.Module,
|
||||
encoder_proj: nn.Module,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(
|
||||
self, feature: Tensor, feature_lens: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x, x_lens, src_key_padding_mask
|
||||
)
|
||||
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
logits = self.encoder_proj(encoder_out)
|
||||
|
||||
return logits, encoder_out_lens
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
# We only profile the encoder part
|
||||
model = Model(
|
||||
encoder=get_encoder_model(params),
|
||||
encoder_embed=get_encoder_embed(params),
|
||||
encoder_proj=get_joiner_model(params).encoder_proj,
|
||||
)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# for 30-second input
|
||||
B, T, D = 1, 3000, 80
|
||||
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
|
||||
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
|
||||
|
||||
flops, params = get_model_profile(
|
||||
model=model,
|
||||
args=(feature, feature_lens),
|
||||
module_hoop_mapping=MODULE_HOOK_MAPPING,
|
||||
)
|
||||
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
||||
1
egs/iwslt22_ta/ASR/zipformer/scaling.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/scaling.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/scaling_converter.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/scaling_converter.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/streaming_beam_search.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/streaming_decode.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/streaming_decode.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/streaming_decode.py
|
||||
1
egs/iwslt22_ta/ASR/zipformer/subsampling.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/subsampling.py
|
||||
1417
egs/iwslt22_ta/ASR/zipformer/train.py
Executable file
1417
egs/iwslt22_ta/ASR/zipformer/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/iwslt22_ta/ASR/zipformer/zipformer.py
Symbolic link
1
egs/iwslt22_ta/ASR/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/zipformer.py
|
||||
28
egs/iwslt22_ta/ST/README.md
Normal file
28
egs/iwslt22_ta/ST/README.md
Normal file
@ -0,0 +1,28 @@
|
||||
# IWSLT_Ta
|
||||
|
||||
The IWSLT Tunisian dataset is a 3-way parallel dataset consisting of approximately 160 hours
|
||||
and 200,000 lines of aligned audio, Tunisian transcripts, and English translations. This dataset
|
||||
comprises conversational telephone speech recorded at a sampling rate of 8kHz. The train, dev,
|
||||
and test1 splits of the iwslt2022 shared task correspond to catalog number LDC2022E01. Please
|
||||
note that access to this data requires an LDC subscription from your institution.To obtain this
|
||||
dataset, you should download the predefined splits by running the following command:
|
||||
git clone https://github.com/kevinduh/iwslt22-dialect.git. For more detailed information about
|
||||
the shared task, please refer to the task paper available at this link:
|
||||
https://aclanthology.org/2022.iwslt-1.10/.
|
||||
|
||||
## Stateless Pruned Transducer Performance Record (after 20 epochs)
|
||||
|
||||
| Decoding method | dev Bleu | test Bleu | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 13, beam(10), pruned range 5 |
|
||||
|
||||
## Zipformer Performance Record (after 20 epochs)
|
||||
|
||||
| Decoding method | dev Bleu | test Bleu | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 13, beam(10),pruned range 5 |
|
||||
| modified beam search | 15.5 | 13 | --epoch 20, --avg 13, beam(20),pruned range 5 |
|
||||
| modified beam search | 17.9 | 14.9 | --epoch 20, --avg 13, beam(20), pruned range 10 |
|
||||
|
||||
|
||||
See [RESULTS](/egs/iwslt_ta/ST/RESULTS.md) for details.
|
||||
125
egs/iwslt22_ta/ST/RESULTS.md
Normal file
125
egs/iwslt22_ta/ST/RESULTS.md
Normal file
@ -0,0 +1,125 @@
|
||||
# Results
|
||||
|
||||
|
||||
### IWSLT Tunisian training results (Stateless Pruned Transducer)
|
||||
|
||||
#### 2023-06-01
|
||||
|
||||
|
||||
| Decoding method | dev Bleu | test Bleu | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 10, beam(10), pruned range 5 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
|
||||
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--max-duration 300 \
|
||||
--bucketing-sampler 1\
|
||||
--num-buckets 50
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/YnzQNCVDSxCvP1onrCzg9A/
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
for method in modified_beam_search; do
|
||||
for epoch in 15 20; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch $epoch \
|
||||
--beam-size 20 \
|
||||
--avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_st \
|
||||
--max-duration 300 \
|
||||
--decoding-method $method \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 12 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 256 \
|
||||
--joiner-dim 256 \
|
||||
--use-averaged-model true
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
### IWSLT Tunisian training results (Zipformer)
|
||||
|
||||
#### 2023-06-01
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||
<https://huggingface.co/AmirHussein/zipformer-iwslt22-Ta>
|
||||
|
||||
|
||||
|
||||
| Decoding method | dev Bleu | test Bleu | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 10, beam(10),pruned range 5 |
|
||||
| modified beam search | 15.5 | 13 | --epoch 20, --avg 10, beam(20),pruned range 5 |
|
||||
| modified beam search | 18.2 | 14.8 | --epoch 20, --avg 10, beam(20), pruned range 10 |
|
||||
|
||||
|
||||
|
||||
To reproduce the above result, use the following commands for training:
|
||||
|
||||
# Note: the model was trained on V-100 32GB GPU
|
||||
# ST medium model 42.5M prune-range 10
|
||||
```
|
||||
|
||||
./zipformer/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 25 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp-st-medium \
|
||||
--causal 0 \
|
||||
--num-encoder-layers 2,2,2,2,2,2 \
|
||||
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||
--encoder-dim 192,256,384,512,384,256 \
|
||||
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||
--max-duration 800 \
|
||||
--prune-range 10 \
|
||||
--warm-step 5000 \
|
||||
--lr-epochs 8 \
|
||||
--base-lr 0.055 \
|
||||
--use-hat False
|
||||
|
||||
```
|
||||
|
||||
|
||||
The decoding command is:
|
||||
|
||||
```
|
||||
for method in modified_beam_search; do
|
||||
for epoch in 15 20; do
|
||||
./zipformer/decode.py \
|
||||
--epoch $epoch \
|
||||
--beam-size 20 \
|
||||
--avg 10 \
|
||||
--exp-dir ./zipformer/exp-st-medium-prun10 \
|
||||
--max-duration 800 \
|
||||
--decoding-method $method \
|
||||
--num-encoder-layers 2,2,2,2,2,2 \
|
||||
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||
--encoder-dim 192,256,384,512,384,256 \
|
||||
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||
--context-size 2 \
|
||||
--use-averaged-model true \
|
||||
--use-hat False
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
0
egs/iwslt22_ta/ST/local/__init__.py
Executable file
0
egs/iwslt22_ta/ST/local/__init__.py
Executable file
58
egs/iwslt22_ta/ST/local/cer.py
Normal file
58
egs/iwslt22_ta/ST/local/cer.py
Normal file
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/python
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
This script computes CER for the decodings generated by icefall recipe
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import jiwer
|
||||
import os
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dec-file",
|
||||
type=str,
|
||||
help="file with decoded text"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def cer_(file):
|
||||
hyp = []
|
||||
ref = []
|
||||
cer_results = 0
|
||||
ref_lens = 0
|
||||
with open(file, 'r', encoding='utf-8') as dec:
|
||||
|
||||
for line in dec:
|
||||
id, target = line.split('\t')
|
||||
id = id[0:-2]
|
||||
target, txt = target.split("=")
|
||||
if target == 'ref':
|
||||
words = txt.strip().strip('[]').split(', ')
|
||||
word_list = [word.strip("'") for word in words]
|
||||
ref.append(" ".join(word_list))
|
||||
elif target == 'hyp':
|
||||
words = txt.strip().strip('[]').split(', ')
|
||||
word_list = [word.strip("'") for word in words]
|
||||
hyp.append(" ".join(word_list))
|
||||
for h, r in zip(hyp, ref):
|
||||
#breakpoint()
|
||||
cer_results += (jiwer.cer(r, h)*len(r))
|
||||
ref_lens += len(r)
|
||||
print(os.path.basename(file))
|
||||
print(cer_results/ref_lens)
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
parse = get_args()
|
||||
args = parse.parse_args()
|
||||
cer_(args.dec_file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
171
egs/iwslt22_ta/ST/local/compute_fbank_gpu.py
Executable file
171
egs/iwslt22_ta/ST/local/compute_fbank_gpu.py
Executable file
@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
# Johns Hopkins University (authors: Amir Hussein)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the MGB2 dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of splits for the train set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Start index of the train set split.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop index of the train set split.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="If set, only compute features for the dev and val set.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_gpu(args):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = os.cpu_count()
|
||||
num_mel_bins = 80
|
||||
sampling_rate = 16000
|
||||
sr = 16000
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
"test1",
|
||||
"dev",
|
||||
)
|
||||
manifests = read_manifests_if_cached(
|
||||
prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
if sr != None:
|
||||
logging.info(f"Resampling to {sr}")
|
||||
cut_set = cut_set.resample(sr)
|
||||
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False,
|
||||
keep_all_channels=False)
|
||||
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.to_eager()
|
||||
chunk_size = len(cut_set) // args.num_splits
|
||||
cut_sets = cut_set.split_lazy(
|
||||
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
|
||||
chunk_size=chunk_size,)
|
||||
start = args.start
|
||||
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
|
||||
num_digits = len(str(args.num_splits))
|
||||
|
||||
for i in range(start, stop):
|
||||
idx = f"{i + 1}".zfill(num_digits)
|
||||
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
|
||||
logging.info(f"Processing train split {i}")
|
||||
cs = cut_sets[i].compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"feats_train_{idx}",
|
||||
batch_duration=1000,
|
||||
num_workers=8,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cs.to_file(cuts_train_idx_path)
|
||||
else:
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"feats_{partition}",
|
||||
batch_duration=1000,
|
||||
num_workers=10,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
|
||||
compute_fbank_gpu(args)
|
||||
1
egs/iwslt22_ta/ST/local/compute_fbank_musan.py
Symbolic link
1
egs/iwslt22_ta/ST/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compute_fbank_musan.py
|
||||
107
egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py
Executable file
107
egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py
Executable file
@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
"""
|
||||
Convert a transcript file containing words to a corpus file containing tokens
|
||||
for LM training with the help of a lexicon.
|
||||
|
||||
If the lexicon contains phones, the resulting LM will be a phone LM; If the
|
||||
lexicon contains word pieces, the resulting LM will be a word piece LM.
|
||||
|
||||
If a word has multiple pronunciations, the one that appears first in the lexicon
|
||||
is kept; others are removed.
|
||||
|
||||
If the input transcript is:
|
||||
|
||||
hello zoo world hello
|
||||
world zoo
|
||||
foo zoo world hellO
|
||||
|
||||
and if the lexicon is
|
||||
|
||||
<UNK> SPN
|
||||
hello h e l l o 2
|
||||
hello h e l l o
|
||||
world w o r l d
|
||||
zoo z o o
|
||||
|
||||
Then the output is
|
||||
|
||||
h e l l o 2 z o o w o r l d h e l l o 2
|
||||
w o r l d z o o
|
||||
SPN z o o w o r l d SPN
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from generate_unique_lexicon import filter_multiple_pronunications
|
||||
|
||||
from icefall.lexicon import read_lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transcript",
|
||||
type=str,
|
||||
help="The input transcript file."
|
||||
"We assume that the transcript file consists of "
|
||||
"lines. Each line consists of space separated words.",
|
||||
)
|
||||
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
|
||||
parser.add_argument(
|
||||
"--oov", type=str, default="<UNK>", help="The OOV word."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_line(
|
||||
lexicon: Dict[str, List[str]], line: str, oov_token: str
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
lexicon:
|
||||
A dict containing pronunciations. Its keys are words and values
|
||||
are pronunciations (i.e., tokens).
|
||||
line:
|
||||
A line of transcript consisting of space(s) separated words.
|
||||
oov_token:
|
||||
The pronunciation of the oov word if a word in `line` is not present
|
||||
in the lexicon.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
s = ""
|
||||
words = line.strip().split()
|
||||
for i, w in enumerate(words):
|
||||
tokens = lexicon.get(w, oov_token)
|
||||
s += " ".join(tokens)
|
||||
s += " "
|
||||
print(s.strip())
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert Path(args.lexicon).is_file()
|
||||
assert Path(args.transcript).is_file()
|
||||
assert len(args.oov) > 0
|
||||
|
||||
# Only the first pronunciation of a word is kept
|
||||
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
|
||||
|
||||
lexicon = dict(lexicon)
|
||||
|
||||
assert args.oov in lexicon
|
||||
|
||||
oov_token = lexicon[args.oov]
|
||||
|
||||
with open(args.transcript) as f:
|
||||
for line in f:
|
||||
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
egs/iwslt22_ta/ST/local/cuts_validate.py
Executable file
107
egs/iwslt22_ta/ST/local/cuts_validate.py
Executable file
@ -0,0 +1,107 @@
|
||||
#!/usr/bin/python
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
This script helps validating the prepared manifests (recordings, supervisions)
|
||||
and CutSets
|
||||
|
||||
"""
|
||||
from lhotse import RecordingSet, SupervisionSet, CutSet
|
||||
import argparse
|
||||
import logging
|
||||
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
|
||||
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sup",
|
||||
type=str,
|
||||
default="",
|
||||
help="Supervisions file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rec",
|
||||
type=str,
|
||||
default="",
|
||||
help="Recordings file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cut",
|
||||
type=str,
|
||||
default="",
|
||||
help="Cutset file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--savecut",
|
||||
type=str,
|
||||
default="",
|
||||
help="name of the cutset to be saved",
|
||||
)
|
||||
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
def valid_asr(cut):
|
||||
tol = 2e-3
|
||||
i=0
|
||||
total_dur = 0
|
||||
for c in cut:
|
||||
if c.supervisions != []:
|
||||
if c.supervisions[0].end > c.duration + tol:
|
||||
|
||||
logging.info(f"Supervision beyond the cut. Cut number: {i}")
|
||||
total_dur += c.duration
|
||||
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
|
||||
elif c.supervisions[0].start < -tol:
|
||||
logging.info(f"Supervision starts before the cut. Cut number: {i}")
|
||||
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
logging.info("Empty supervision")
|
||||
logging.info(f"id: {c.id}")
|
||||
i += 1
|
||||
logging.info(f"filtered duration: {total_dur}")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
if args.cut != "":
|
||||
cuts = CutSet.from_file(args.cut)
|
||||
else:
|
||||
|
||||
recordings = RecordingSet.from_file(args.rec)
|
||||
supervisions = SupervisionSet.from_file(args.sup)
|
||||
logging.info("Example from supervisions:")
|
||||
logging.info(supervisions[0])
|
||||
logging.info("Example from recordings")
|
||||
print(recordings[0])
|
||||
logging.info("Fixing manifests")
|
||||
recordings, supervisions = fix_manifests(recordings, supervisions)
|
||||
|
||||
logging.info("Validating manifests")
|
||||
validate_recordings_and_supervisions(recordings, supervisions)
|
||||
|
||||
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
|
||||
|
||||
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
|
||||
logging.info("Example from cut:")
|
||||
print(cuts[100])
|
||||
cuts.describe()
|
||||
logging.info("Validating manifests for ASR")
|
||||
valid_asr(cuts)
|
||||
if args.savecut != "":
|
||||
cuts.to_file(args.savecut)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/iwslt22_ta/ST/local/display_manifest_statistics.py
Symbolic link
1
egs/iwslt22_ta/ST/local/display_manifest_statistics.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/display_manifest_statistics.py
|
||||
1
egs/iwslt22_ta/ST/local/generate_unique_lexicon.py
Symbolic link
1
egs/iwslt22_ta/ST/local/generate_unique_lexicon.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/generate_unique_lexicon.py
|
||||
21
egs/iwslt22_ta/ST/local/prep_lexicon.sh
Executable file
21
egs/iwslt22_ta/ST/local/prep_lexicon.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2022 QCRI (author: Amir Hussein)
|
||||
# Apache 2.0
|
||||
# This script prepares the graphemic lexicon.
|
||||
|
||||
dir=data/local/dict
|
||||
stage=0
|
||||
lang_dir_src=$1
|
||||
lang_dir_tgt=$2
|
||||
|
||||
cat $lang_dir_src/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_src/uniq_words
|
||||
cat $lang_dir_tgt/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_tgt/uniq_words
|
||||
|
||||
echo "$0: processing lexicon text and creating lexicon... $(date)."
|
||||
# remove vowels and rare alef wasla
|
||||
cat $lang_dir_src/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir_src/words.txt
|
||||
cat $lang_dir_tgt/uniq_words | sed -r '/^\s*$/d' | sort -u > $lang_dir_tgt/words.txt
|
||||
|
||||
|
||||
echo "$0: Lexicon preparation succeeded"
|
||||
1
egs/iwslt22_ta/ST/local/prepare_lang.py
Symbolic link
1
egs/iwslt22_ta/ST/local/prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang.py
|
||||
1
egs/iwslt22_ta/ST/local/prepare_lang_bpe.py
Symbolic link
1
egs/iwslt22_ta/ST/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang_bpe.py
|
||||
39
egs/iwslt22_ta/ST/local/prepare_lexicon.py
Executable file
39
egs/iwslt22_ta/ST/local/prepare_lexicon.py
Executable file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
# Apache 2.0
|
||||
|
||||
# This script prepares givel a column of words lexicon.
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""Creates the list of characters and words in lexicon"""
|
||||
)
|
||||
parser.add_argument("input", type=str, help="""Input list of words file""")
|
||||
parser.add_argument("output", type=str, help="""output graphemic lexicon""")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
lex = {}
|
||||
args = get_args()
|
||||
with open(args.input, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
characters = list(line)
|
||||
characters = " ".join(
|
||||
["V" if char == "*" else char for char in characters]
|
||||
)
|
||||
lex[line] = characters
|
||||
|
||||
with open(args.output, "w", encoding="utf-8") as fp:
|
||||
for key in sorted(lex):
|
||||
fp.write(key + " " + lex[key] + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
66
egs/iwslt22_ta/ST/local/prepare_transcripts.py
Executable file
66
egs/iwslt22_ta/ST/local/prepare_transcripts.py
Executable file
@ -0,0 +1,66 @@
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
|
||||
#!/usr/bin/python
|
||||
"""
|
||||
This script prepares transcript_words.txt from cutset
|
||||
"""
|
||||
|
||||
from lhotse import CutSet
|
||||
import argparse
|
||||
import logging
|
||||
import pdb
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cut",
|
||||
type=str,
|
||||
default="",
|
||||
help="Cutset file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--src-langdir",
|
||||
type=str,
|
||||
default="",
|
||||
help="name of the source lang-dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tgt-langdir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="name of the target lang-dir",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.info("Reading the cuts")
|
||||
cuts = CutSet.from_file(args.cut)
|
||||
if args.tgt_langdir != None:
|
||||
logging.info("Target dir is not None")
|
||||
langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)]
|
||||
else:
|
||||
langdirs = [Path(args.src_langdir)]
|
||||
|
||||
for langdir in langdirs:
|
||||
if not os.path.exists(langdir):
|
||||
os.makedirs(langdir)
|
||||
|
||||
with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt:
|
||||
for c in cuts:
|
||||
src_txt = c.supervisions[0].text
|
||||
tgt_txt = c.supervisions[0].custom['translated_text']['eng']
|
||||
src.write(src_txt + '\n')
|
||||
tgt.write(tgt_txt + '\n')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/iwslt22_ta/ST/local/test_prepare_lang.py
Symbolic link
1
egs/iwslt22_ta/ST/local/test_prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/test_prepare_lang.py
|
||||
1
egs/iwslt22_ta/ST/local/train_bpe_model.py
Symbolic link
1
egs/iwslt22_ta/ST/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/train_bpe_model.py
|
||||
243
egs/iwslt22_ta/ST/local/transcript_cleaning.py
Executable file
243
egs/iwslt22_ta/ST/local/transcript_cleaning.py
Executable file
@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 Kanari AI (Amir Hussein)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
import pdb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import re
|
||||
import string
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import pyarabic.number as number
|
||||
from pyarabic import araby
|
||||
|
||||
_unicode = u"\u0622\u0624\u0626\u0628\u062a\u062c\u06af\u062e\u0630\u0632\u0634\u0636\u0638\u063a\u0640\u0642\u0644\u0646\u0648\u064a\u064c\u064e\u0650\u0652\u0670\u067e\u0686\u0621\u0623\u0625\u06a4\u0627\u0629\u062b\u062d\u062f\u0631\u0633\u0635\u0637\u0639\u0641\u0643\u0645\u0647\u0649\u064b\u064d\u064f\u0651\u0671"
|
||||
_buckwalter = u"|&}btjGx*z$DZg_qlnwyNaio`PJ'><VApvHdrsSTEfkmhYFKu~{"
|
||||
_backwardMap = {ord(b):a for a,b in zip(_buckwalter, _unicode)}
|
||||
|
||||
|
||||
# def number2text(anumber):
|
||||
# """
|
||||
# Convert number to arabic words, for example convert 25 --> خمسة و عشرون
|
||||
# Example:
|
||||
# >>> number2text(523)
|
||||
# خمسمئة وثلاث وعشرون
|
||||
# @param anumber: input number
|
||||
# @type anumber: int
|
||||
# @return: number words
|
||||
# @rtype: unicode
|
||||
# """
|
||||
# # test if the given type is numeric(float or int
|
||||
# # if ok, convert it to string
|
||||
# if type(anumber) is int or type(anumber) is float:
|
||||
# anumber = str(anumber)
|
||||
# # if the given type is str/unicode, test if it's a valid number
|
||||
# elif type(anumber) is str or type(anumber) is unicode:
|
||||
# try:
|
||||
# a = float(anumber);
|
||||
# except ValueError:
|
||||
# return u"صفر"
|
||||
# # if the given number not a valid return 0
|
||||
# else:
|
||||
# return u"صفر"
|
||||
# arbn = number.ArNumbers()
|
||||
# arbn.set_feminine(2)
|
||||
# return arbn.int2str(anumber)
|
||||
|
||||
# return total
|
||||
|
||||
|
||||
def fromBuckWalter(s):
|
||||
return s.translate(_backwardMap)
|
||||
|
||||
def read_tsv(data_file):
|
||||
text_data = list()
|
||||
infile = open(data_file, encoding='utf-8')
|
||||
for line in infile:
|
||||
if not line.strip():
|
||||
continue
|
||||
text= line.split('\t')
|
||||
text_data.append(text)
|
||||
|
||||
return text_data
|
||||
|
||||
words_to_remove =['#غير_واضح', '#تلعثم', '#آ', '#أ', '#ال', '#بال', '#وال', ' #وب', '###تداخل', 'FRN','RFN',\
|
||||
' #سي', '#يي', '#هـ' , '#لل', '#بم', '#الش', '#آآ', ' #يت', '#وو', \
|
||||
'#ومش', '#ول', '#وسي', '#غير_معروف','#العا','#مطا', '#محم' ,' #ماث','#متطو', ' #نشا', '#أأ', '#آآآ', ' #استج']
|
||||
|
||||
|
||||
def normalizeArabic(text):
|
||||
# text = re.sub("[إأٱآا]", "ا", text)
|
||||
# text = re.sub("ى", "ي", text)
|
||||
# text = re.sub("ة", "ه", text)
|
||||
# text = re.sub("ئ", "ء", text)
|
||||
# text = re.sub("ؤ", "ء", text)
|
||||
text = re.sub(r"(ه){2,}", "ههه", text)
|
||||
text = re.sub(r"(أ){2,}", "ا", text)
|
||||
text = re.sub(r"(ا){2,}", "ا", text)
|
||||
text = re.sub(r"(آ){2,}", "ا", text)
|
||||
text = re.sub(r"(ص){2,}", "ص", text)
|
||||
text = re.sub(r"(و){2,}", "و", text)
|
||||
return text
|
||||
|
||||
def remove_special_words(text):
|
||||
# remove special words from the transcription text
|
||||
for word in words_to_remove:
|
||||
if word in text:
|
||||
text = text.replace(word, '')
|
||||
return text
|
||||
|
||||
def remove_hashes(text):
|
||||
return re.sub(r'#?', '', text)
|
||||
|
||||
def remove_english_characters(text):
|
||||
return re.sub(r'[^\u0600-\u06FF0-9\s]+', '', text)
|
||||
|
||||
def remove_diacritics(text):
|
||||
return re.sub(r'[\u064B-\u0652\u06D4\u0670\u0674\u06D5-\u06ED]+', '', text)
|
||||
|
||||
def remove_punctuations(text):
|
||||
""" This function removes all punctuations except the verbatim """
|
||||
|
||||
arabic_punctuations = '''﴿﴾‘`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ'''
|
||||
english_punctuations = string.punctuation
|
||||
all_punctuations = set(arabic_punctuations + english_punctuations)-{'@','%','.'} # remove all non verbatim punctuations
|
||||
|
||||
for p in all_punctuations:
|
||||
if p in text:
|
||||
text = text.replace(p, '')
|
||||
text = re.sub('\s+\.','',text) # keep only the "." that is part of a word: marsad@aljazeera.net . => marsad@aljazeera.net
|
||||
return text
|
||||
|
||||
def remove_extra_space(text):
|
||||
text = text.lower()
|
||||
text = re.sub('\s+', ' ', text)
|
||||
text = re.sub('\s+\.\s+', '.', text)
|
||||
return text
|
||||
|
||||
def remove_dot(text):
|
||||
|
||||
words = text.split()
|
||||
res = []
|
||||
for word in words:
|
||||
if word.replace('.','').isnumeric(): # remove the dot if it is not part of a number
|
||||
res.append(word)
|
||||
|
||||
else:
|
||||
word = re.sub('\s+\.','',word)
|
||||
res.append(word)
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def east_to_west_num(text):
|
||||
eastern_to_western = {"٠":"0","١":"1","٢":"2","٣":"3","٤":"4","٥":"5","٦":"6","٧":"7","٨":"8","٩":"9","٪":"%","_":" ","ڤ":"ف","|":" "}
|
||||
trans_string = str.maketrans(eastern_to_western)
|
||||
return text.translate(trans_string)
|
||||
|
||||
def remove_repeating_char(text):
|
||||
return re.sub(r'(.)\1+', r'\1', text)
|
||||
|
||||
def remove_single_char_word(text):
|
||||
"""
|
||||
Remove single character word from text
|
||||
Example: I am in a a home for two years => am in home for two years
|
||||
Args:
|
||||
text (str): text
|
||||
Returns:
|
||||
(str): text with single char removed
|
||||
"""
|
||||
words = text.split()
|
||||
|
||||
filter_words = [word for word in words if len(word) > 1 or word.isnumeric()]
|
||||
return " ".join(filter_words)
|
||||
|
||||
def seperate_english_characters(text):
|
||||
text = text.lower()
|
||||
res = re.findall(r'[a-z]+', text) # search for english
|
||||
for match in res:
|
||||
if match not in {'.',' '}:
|
||||
text = re.sub(match, " "+ match+ " ",text)
|
||||
text = re.sub('\s+', ' ', text)
|
||||
return text
|
||||
|
||||
def digit2num(text, dig2num=False):
|
||||
|
||||
""" This function is used to clean numbers"""
|
||||
|
||||
# search for numbers with spaces
|
||||
# 100 . 000 => 100.000
|
||||
|
||||
res = re.search('[0-9]+\s\.\s[0-9]+', text)
|
||||
if res != None:
|
||||
t = re.sub(r'\s', '', res[0])
|
||||
text = re.sub(res[0], t, text)
|
||||
|
||||
# seperate numbers glued with words
|
||||
# 3أشهر => 3 أشهر
|
||||
# من10الى15 => من 10 الى 15
|
||||
# pdb.set_trace()
|
||||
res = re.findall(r'[^\u0600-\u06FF\%\@a-z]+', text) # search for digits
|
||||
for match in res:
|
||||
if match not in {'.',' '}:
|
||||
text = re.sub(match, " "+ match+ " ",text)
|
||||
text = re.sub('\s+', ' ', text)
|
||||
|
||||
# transliterate numbers to digits
|
||||
# 13 => ثلاثة عشر
|
||||
|
||||
if dig2num == True:
|
||||
words = araby.tokenize(text)
|
||||
for i in range(len(words)):
|
||||
digit = re.sub(r'[\u0600-\u06FF]+', '', words[i])
|
||||
if digit.isnumeric():
|
||||
sub_word = re.sub(r'[^\u0600-\u06FF]+', '', words[i])
|
||||
if number.number2text(digit) != 'صفر':
|
||||
words[i] = sub_word + number.number2text(digit)
|
||||
else:
|
||||
pass
|
||||
|
||||
return " ".join(words)
|
||||
else:
|
||||
return text
|
||||
|
||||
def data_cleaning(text):
|
||||
# text = remove_special_words(text)
|
||||
text = remove_punctuations(text)
|
||||
text = remove_single_char_word(text)
|
||||
text = remove_diacritics(text)
|
||||
text = seperate_english_characters(text)
|
||||
text = remove_extra_space(text)
|
||||
text = remove_dot(text)
|
||||
text = east_to_west_num(text)
|
||||
text = digit2num(text, True)
|
||||
text = normalizeArabic(text)
|
||||
#text = re.sub(r'#\w{1,2}\b', '', text) # text = re.sub(r'#\w{1,3}\b', '', text)
|
||||
#text = remove_hashes(text)
|
||||
#text = normalizeArabic(text)
|
||||
return text
|
||||
|
||||
def main():
|
||||
input_file = sys.argv[1] # input transcription file with format <id> <text>
|
||||
to_BW = str(sys.argv[2]) # transform to BW True|False
|
||||
output_file=sys.argv[3] # output file name
|
||||
data = read_tsv(input_file)
|
||||
new_data = []
|
||||
for i in range(len(data)):
|
||||
|
||||
tokens = data[i][0].split()
|
||||
|
||||
tokens[1:] = data_cleaning(" ".join(tokens[1:])).split()
|
||||
#tokens = data_cleaning(" ".join(tokens)).split()
|
||||
if to_BW == "True":
|
||||
for i in range(len(tokens[1:])):
|
||||
tokens[i+1] = fromBuckWalter(tokens[i+1])
|
||||
new_data.append(" ".join(tokens))
|
||||
else:
|
||||
new_data.append(" ".join(tokens))
|
||||
|
||||
df = pd.DataFrame(data=new_data)
|
||||
df.to_csv(output_file, sep = '\n', header=False, index=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
egs/iwslt22_ta/ST/prepare.sh
Executable file
183
egs/iwslt22_ta/ST/prepare.sh
Executable file
@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=20
|
||||
stage=1
|
||||
stop_stage=4
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files.
|
||||
#
|
||||
# - $dl_dir/iwslt_ta
|
||||
#
|
||||
# You can download the data from
|
||||
#
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
#
|
||||
# Note: iwslt_ta is not available for direct
|
||||
# download, "Download IWSLT Tunisian from LDC LDC2022E01. This script assumes you prepared the stm files"
|
||||
#"Check the instructions to prepare the stm files from the raw data here https://github.com/kevinduh/iwslt22-dialect"
|
||||
|
||||
dl_dir=$PWD/download
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bpe_xxx,
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
1000
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/iwslt_ta,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/iwslt_ta $dl_dir/iwslt_ta
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
fbank=data/fbank
|
||||
manifests=data/manifests
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare iwslt manifest"
|
||||
# We assume that you have downloaded the iwslt_ta corpus to $dl_dir/iwslt_ta
|
||||
# Also git clone https://github.com/kevinduh/iwslt22-dialect
|
||||
if [ ! -d "iwslt22-dialect" ]; then
|
||||
echo "Splits directory (iwslt22-dialect) does not exist"
|
||||
echo "Run git clone https://github.com/kevinduh/iwslt22-dialect"
|
||||
exit 1
|
||||
fi
|
||||
manifests=data/manifests
|
||||
mkdir -p $manifests
|
||||
|
||||
lhotse prepare iwslt_ta $dl_dir/iwslt_ta iwslt22-dialect data/manifests
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
mkdir -p $manifests
|
||||
lhotse prepare musan $dl_dir/musan $manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute fbank features"
|
||||
mkdir -p ${fbank}
|
||||
python local/compute_fbank_gpu.py --num-splits 20
|
||||
|
||||
log "Combine features from train splits (may take ~1h)"
|
||||
if [ ! -f $manifests/cuts_train.jsonl.gz ]; then
|
||||
pieces=$(find $manifests -name "cuts_train_[0-9]*.jsonl.gz")
|
||||
lhotse combine $pieces $manifests/cuts_train.jsonl.gz
|
||||
fi
|
||||
gunzip -c $manifests/cuts_train.jsonl.gz | shuf | gzip -c > ${fbank}/cuts_train_shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
mkdir -p ${fbank}
|
||||
./local/compute_fbank_musan.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare phone based lang"
|
||||
lang_dir_src=data/lang_phone_src
|
||||
lang_dir_tgt=data/lang_phone_tgt
|
||||
|
||||
if [ ! -f download/lm/train_src/transcript_words.txt ] || [ ! -f download/lm/train_tgt/transcript_words.txt ]; then
|
||||
# export train text file to build grapheme lexicon
|
||||
log "Creating transcripts in download/lm/train from lhotse cuts"
|
||||
mkdir -p download/lm/train_src
|
||||
mkdir -p download/lm/train_tgt
|
||||
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --src-langdir download/lm/train_src --tgt-langdir download/lm/train_tgt
|
||||
fi
|
||||
mkdir -p $lang_dir_src
|
||||
mkdir -p $lang_dir_tgt
|
||||
|
||||
log "Prepare lexicon"
|
||||
|
||||
./local/prep_lexicon.sh download/lm/train_src download/lm/train_tgt
|
||||
python local/prepare_lexicon.py $dl_dir/lm/train_src/words.txt $dl_dir/lm/train_src/lexicon.txt
|
||||
python local/prepare_lexicon.py $dl_dir/lm/train_tgt/words.txt $dl_dir/lm/train_tgt/lexicon.txt
|
||||
|
||||
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
|
||||
cat - $dl_dir/lm/train_src/lexicon.txt |
|
||||
sort | uniq > $lang_dir_src/lexicon.txt
|
||||
|
||||
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
|
||||
cat - $dl_dir/lm/train_tgt/lexicon.txt |
|
||||
sort | uniq > $lang_dir_tgt/lexicon.txt
|
||||
|
||||
if [ ! -f $lang_dir_src/L_disambig.pt ]; then
|
||||
./local/prepare_lang.py --lang-dir $lang_dir_src
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir_tgt/L_disambig.pt ]; then
|
||||
./local/prepare_lang.py --lang-dir $lang_dir_tgt
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare BPE based lang"
|
||||
srctag=ta
|
||||
tgttag=en
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
src_lang_dir=data/lang_bpe_${srctag}_${vocab_size}
|
||||
tgt_lang_dir=data/lang_bpe_${tgttag}_${vocab_size}
|
||||
mkdir -p ${src_lang_dir}
|
||||
mkdir -p ${tgt_lang_dir}
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp data/lang_phone_src/words.txt $src_lang_dir
|
||||
cp data/lang_phone_tgt/words.txt $tgt_lang_dir
|
||||
if [ ! -f $src_lang_dir/transcript_words.txt ] || [ ! -f $tgt_lang_dir/transcript_words.txt ]; then
|
||||
log "Generate data for ${srctag} and ${tgttag} BPE training from data/fbank/cuts_train_shuf.jsonl.gz"
|
||||
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --src-langdir ${src_lang_dir} --tgt-langdir ${tgt_lang_dir}
|
||||
fi
|
||||
for lang_dir in $src_lang_dir $tgt_lang_dir; do
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
fi
|
||||
done
|
||||
done
|
||||
fi
|
||||
396
egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py
Normal file
396
egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py
Normal file
@ -0,0 +1,396 @@
|
||||
# Copyright 2023 Amir Hussein
|
||||
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2Speech2textTranslationDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class IWSLTDialectSTDataModule:
|
||||
|
||||
"""
|
||||
DataModule for k2 ST experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2Speech2textTranslationDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=8,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2Speech2textTranslationDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz")
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
|
||||
1809
egs/iwslt22_ta/ST/pruned_transducer_stateless5/conformer.py
Normal file
1809
egs/iwslt22_ta/ST/pruned_transducer_stateless5/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
948
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py
Executable file
948
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py
Executable file
@ -0,0 +1,948 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless5/decode_st.py \
|
||||
--epoch 12 \
|
||||
--beam-size 20 \
|
||||
--avg 3 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_st \
|
||||
--max-duration 400 \
|
||||
--decoding-method modified_beam_search \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 12 \
|
||||
--dim-feedforward 1024 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 256 \
|
||||
--decoder-dim 256 \
|
||||
--joiner-dim 256 \
|
||||
--use-averaged-model true
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import pdb
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from lhotse.qa import validate_cut
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import IWSLTDialectSTDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_rnnlm_shallow_fusion,
|
||||
)
|
||||
from train_st import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_translations,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_ta_1000/bpe.model",
|
||||
help="Path to source data BPE model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpe-tgt-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_en_1000/bpe.model",
|
||||
help="Path to target data BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_ta_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-tgt-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_en_1000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
def remove_short_and_long_utt(c):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 0.5 or c.duration > 30.0:
|
||||
#logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
#)
|
||||
return False
|
||||
if c.supervisions == []:
|
||||
return False
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
# where T is the number of feature frames after subsampling
|
||||
# and S is the number of tokens in the utterance
|
||||
|
||||
# In ./conformer.py, the conv module uses the following expression
|
||||
# for subsamplin
|
||||
|
||||
return True
|
||||
|
||||
# def remove_seg(c):
|
||||
# if c.supervisions[0].id != 'fla_0102_1_0B_00107':
|
||||
# return True
|
||||
# else:
|
||||
# return False
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
texts_tgt = batch["supervisions"]["tgt_text"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
#breakpoint()
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
|
||||
for cut_id, hyp_words, ref_text, ref_text_tgt in zip(cut_ids, hyps, texts, texts_tgt):
|
||||
ref_words = ref_text.split()
|
||||
ref_words_tgt = ref_text_tgt.split()
|
||||
this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
#breakpoint()
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_translations(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_translations(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
IWSLTDialectSTDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_rnnlm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_tgt_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
rnn_lm_model = None
|
||||
rnn_lm_scale = params.rnn_lm_scale
|
||||
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
assert params.rnn_lm_avg == 1
|
||||
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_tgt_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_tgt_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
iwslt_ta = IWSLTDialectSTDataModule(args)
|
||||
|
||||
test_cuts = iwslt_ta.test_cuts()
|
||||
dev_cuts = iwslt_ta.dev_cuts()
|
||||
|
||||
# lev_test_cuts = lev_test_cuts.filter(remove_short_and_long_utt)
|
||||
# # lev_test_cuts = lev_test_cuts.filter(remove_seg)
|
||||
# gulf_test_cuts = gulf_test_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_test_cuts = egy_test_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_h5_cuts = egy_sup_cuts.filter(remove_short_and_long_utt)
|
||||
# egy_sup_cuts = egy_h5_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
test_dl = iwslt_ta.test_dataloaders(test_cuts)
|
||||
dev_dl = iwslt_ta.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_all_dl = [test_dl, dev_dl]
|
||||
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_all_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
rnnlm=rnn_lm_model,
|
||||
rnnlm_scale=rnn_lm_scale,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/export.py
|
||||
1306
egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py
Executable file
1306
egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
1323
egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py
Executable file
1323
egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/model.py
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
|
||||
352
egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py
Executable file
352
egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py
Executable file
@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless5/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py
Symbolic link
1
egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
|
||||
65
egs/iwslt22_ta/ST/pruned_transducer_stateless5/test_model.py
Executable file
65
egs/iwslt22_ta/ST/pruned_transducer_stateless5/test_model.py
Executable file
@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 24
|
||||
params.dim_feedforward = 1536 # 384 * 4
|
||||
params.encoder_dim = 384
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||
def test_model_M():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 18
|
||||
params.dim_feedforward = 1024
|
||||
params.encoder_dim = 256
|
||||
params.nhead = 4
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
def main():
|
||||
# test_model_1()
|
||||
test_model_M()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user