Remove multidataset from librispeech/pruned_transducer_stateless7 (#1105)

* Add People's Speech to multidataset

* update

* remove multi from librispeech
Yifan Yang 2023-06-01 18:45:20 +08:00 committed by GitHub
parent 7a604057f9
commit 82f34a2388
6 changed files with 74 additions and 780 deletions

View File

@@ -1,117 +0,0 @@
#!/usr/bin/env bash
set -eou pipefail
nj=16
stage=-1
stop_stage=100
# Split the ${lang} train subset into this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=1000
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/$release/$lang
# This directory contains the following files downloaded from
# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
#
# - clips
# - dev.tsv
# - invalidated.tsv
# - other.tsv
# - reported.tsv
# - test.tsv
# - train.tsv
# - validated.tsv
dl_dir=$PWD/download
release=cv-corpus-13.0-2023-03-09
lang=en
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data/${lang}".
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
mkdir -p data/${lang}
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/$release,
# you can create a symlink
#
# ln -sfv /path/to/$release $dl_dir/$release
#
if [ ! -d $dl_dir/$release/$lang/clips ]; then
lhotse download commonvoice --languages $lang --release $release $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare CommonVoice manifest"
# We assume that you have downloaded the CommonVoice corpus
# to $dl_dir/$release
mkdir -p data/${lang}/manifests
if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
touch data/${lang}/manifests/.cv-${lang}.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Preprocess CommonVoice manifest"
if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
./local/preprocess_commonvoice.py --language $lang
touch data/${lang}/fbank/.preprocess_complete
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for dev and test subsets of CommonVoice"
mkdir -p data/${lang}/fbank
if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
./local/compute_fbank_commonvoice_dev_test.py --language $lang
touch data/${lang}/fbank/.cv-${lang}_dev_test.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split train subset into ${num_splits} pieces"
split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits}
if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
touch $split_dir/.cv-${lang}_train_split.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute features for train subset of CommonVoice"
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
./local/compute_fbank_commonvoice_splits.py \
--num-workers $nj \
--batch-duration 600 \
--start 0 \
--num-splits $num_splits \
--language $lang
touch data/${lang}/fbank/.cv-${lang}_train.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Combine features for train"
if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then
pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
fi
fi
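
The script ends with a single combined manifest, data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz. Below is a minimal sketch for sanity-checking that output with lhotse's Python API; it assumes lang=en, that lhotse is installed, and that the path follows the Stage 6 output above.

from lhotse import load_manifest_lazy

# Iterate the combined CommonVoice training manifest lazily and report
# basic statistics; print the first few cuts to confirm that text and
# features are attached.
cuts = load_manifest_lazy("data/en/fbank/cv-en_cuts_train.jsonl.gz")

num_cuts = 0
total_seconds = 0.0
for cut in cuts:
    if num_cuts < 3:
        print(cut.id, f"{cut.duration:.2f}s", cut.supervisions[0].text)
    num_cuts += 1
    total_seconds += cut.duration

print(f"num cuts: {num_cuts}, total hours: {total_seconds / 3600:.1f}")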

View File

@@ -1,159 +0,0 @@
#!/usr/bin/env bash
set -eou pipefail
nj=15
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/GigaSpeech
# You can find audio, dict, GigaSpeech.json inside it.
# You can apply for the download credentials by following
# https://github.com/SpeechColab/GigaSpeech#download
# Number of hours for GigaSpeech subsets
# XL 10k hours
# L 2.5k hours
# M 1k hours
# S 250 hours
# XS 10 hours
# DEV 12 hours
# Test 40 hours
# Split the XL subset into this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=2000
# We use lazy split from lhotse.
# The XL subset (10k hours) contains 37956 cuts without speed perturbing.
# We aim for roughly 2000 splits, so each split
# contains about 37956 / 2000 = 19 cuts. Splitting by chunks of 19 cuts yields 1998 splits.
chunk_size=19 # number of cuts in each split. The last split may contain fewer cuts.
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
[ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
# If you have pre-downloaded it to /path/to/GigaSpeech,
# you can create a symlink
#
# ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
#
if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
# Check credentials.
if [ ! -f $dl_dir/password ]; then
echo -n "$0: Please apply for the download credentials by following "
echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
echo " and save it to $dl_dir/password."
exit 1;
fi
PASSWORD=`cat $dl_dir/password 2>/dev/null`
if [ -z "$PASSWORD" ]; then
echo "$0: Error, $dl_dir/password is empty."
exit 1;
fi
PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
echo "$0: Error, invalid $dl_dir/password."
exit 1;
fi
# Download all subsets (XL, L, M, S, XS, DEV and TEST).
lhotse download gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)"
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
if [ ! -f data/manifests/.gigaspeech.done ]; then
mkdir -p data/manifests
lhotse prepare gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
touch data/manifests/.gigaspeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Preprocess GigaSpeech manifest"
if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then
log "It may take 2 hours for this stage"
./local/preprocess_gigaspeech.py
touch data/fbank/.gigaspeech_preprocess.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
if [ ! -f data/fbank/.gigaspeech_dev_test.done ]; then
./local/compute_fbank_gigaspeech_dev_test.py
touch data/fbank/.gigaspeech_dev_test.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split XL subset into ${num_splits} pieces"
split_dir=data/fbank/gigaspeech_XL_split_${num_splits}
if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size
touch $split_dir/.gigaspeech_XL_split.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute features for XL"
# Note: The script supports --start and --stop options.
# You can use several machines to compute the features in parallel.
if [ ! -f data/fbank/.gigaspeech_XL.done ]; then
./local/compute_fbank_gigaspeech_splits.py \
--num-workers $nj \
--batch-duration 600 \
--num-splits $num_splits
touch data/fbank/.gigaspeech_XL.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Combine features for XL (may take 15 hours)"
if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then
pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz")
lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz
fi
fi
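
The chunk_size above follows from simple arithmetic; here is a worked check in Python (the cut count 37956 comes from the comment near the top of the script):

import math

num_cuts = 37956   # XL subset without speed perturbation (see comment above)
num_splits = 2000  # requested number of pieces

# Each piece holds roughly num_cuts / num_splits cuts, rounded up.
chunk_size = math.ceil(num_cuts / num_splits)     # 19
actual_pieces = math.ceil(num_cuts / chunk_size)  # 1998

print(chunk_size, actual_pieces)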

View File

@@ -1,330 +0,0 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=16
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/LibriSpeech
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
# You can download them from https://www.openslr.org/12
#
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
#
# - 3-gram.pruned.1e-7.arpa.gz
# - 3-gram.pruned.1e-7.arpa
# - 4-gram.arpa.gz
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
# Split each dataset into this number of pieces and mix the pieces of
# all datasets into multidataset pieces with shuffling.
num_splits=1998
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# multidataset list.
# LibriSpeech and musan are required.
# The others are optional.
multidataset=(
  "gigaspeech"
  "commonvoice"
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
log "Dataset: LibriSpeech and musan"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM"
mkdir -p $dl_dir/lm
if [ ! -e $dl_dir/lm/.done ]; then
./local/download_lm.py --out-dir=$dl_dir/lm
touch $dl_dir/lm/.done
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
#
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
#
if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
lhotse download librispeech --full $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LibriSpeech manifest"
# We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/LibriSpeech
mkdir -p data/manifests
if [ ! -e data/manifests/.librispeech.done ]; then
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
touch data/manifests/.librispeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for librispeech"
mkdir -p data/fbank
if [ ! -e data/fbank/.librispeech.done ]; then
./local/compute_fbank_librispeech.py --perturb-speed False
touch data/fbank/.librispeech.done
fi
if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
fi
if [ ! -e data/fbank/.librispeech-validated.done ]; then
log "Validating data/fbank for LibriSpeech"
parts=(
train-clean-100
train-clean-360
train-other-500
test-clean
test-other
dev-clean
dev-other
)
for part in ${parts[@]}; do
python3 ./local/validate_manifest.py \
data/fbank/librispeech_cuts_${part}.jsonl.gz
done
touch data/fbank/.librispeech-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have installed kaldilm. If not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
# Note: If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
# Note: If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile LG"
./local/compile_lg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare the other datasets"
# GigaSpeech
if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
log "Dataset: GigaSpeech"
./prepare_giga_speech.sh --stop_stage 5
fi
# CommonVoice
if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
log "Dataset: CommonVoice"
./prepare_common_voice.sh
fi
fi
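
Once the optional stages above have finished, the MultiDataset helper whose (deleted) source appears later in this commit expects a specific manifest layout. The pre-flight check below is a sketch for illustration only: the helper name is made up, and the directory values data/fbank and data/en/fbank are assumed defaults; the expected file names follow the docstring of multidataset.py.

from glob import glob
from pathlib import Path


def check_multidataset_manifests(manifest_dir: str, cv_manifest_dir: str) -> None:
    """Hypothetical helper: verify that the manifests MultiDataset expects exist."""
    manifest_dir = Path(manifest_dir)
    cv_manifest_dir = Path(cv_manifest_dir)

    # LibriSpeech is required (see the multidataset list above).
    librispeech = manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
    assert librispeech.is_file(), f"Missing required manifest: {librispeech}"

    # GigaSpeech XL arrives as many split manifests.
    giga = glob(str(manifest_dir / "gigaspeech_XL_split_2000" / "gigaspeech_cuts_XL.*.jsonl.gz"))
    print(f"Found {len(giga)} GigaSpeech XL split manifests")

    # CommonVoice (en) is a single combined manifest.
    cv = cv_manifest_dir / "cv-en_cuts_train.jsonl.gz"
    print(f"CommonVoice manifest present: {cv.is_file()}")


check_multidataset_manifests("data/fbank", "data/en/fbank")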

View File

@@ -1,77 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
import lhotse
from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
    def __init__(self, manifest_dir: str, cv_manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:
            - librispeech_cuts_train-all-shuf.jsonl.gz
            - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
          cv_manifest_dir:
            It is expected to contain the following files:
            - cv-en_cuts_train.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)
        self.cv_manifest_dir = Path(cv_manifest_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # LibriSpeech
        logging.info("Loading LibriSpeech in lazy mode")
        librispeech_cuts = load_manifest_lazy(
            self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
        )

        # GigaSpeech
        filenames = glob.glob(
            f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
        )
        pattern = re.compile(r"gigaspeech_cuts_XL\.([0-9]+)\.jsonl\.gz")
        idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
        idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
        sorted_filenames = [f[1] for f in idx_filenames]
        logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
        gigaspeech_cuts = lhotse.combine(
            lhotse.load_manifest_lazy(p) for p in sorted_filenames
        )

        # CommonVoice
        logging.info("Loading CommonVoice in lazy mode")
        commonvoice_cuts = load_manifest_lazy(
            self.cv_manifest_dir / "cv-en_cuts_train.jsonl.gz"
        )

        return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)
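
A usage sketch for the class above. The directory arguments are assumptions that match the outputs of the prepare scripts earlier in this commit, and the import assumes the file is saved as multidataset.py on the Python path.

import logging

from multidataset import MultiDataset

logging.basicConfig(level=logging.INFO)

multidataset = MultiDataset(
    manifest_dir="data/fbank",        # LibriSpeech + GigaSpeech manifests
    cv_manifest_dir="data/en/fbank",  # CommonVoice (en) manifests
)

# CutSet.mux interleaves the three corpora lazily, so nothing is loaded
# into memory up front; peek at a few cuts to confirm the wiring works.
train_cuts = multidataset.train_cuts()
for i, cut in enumerate(train_cuts):
    print(cut.id, cut.duration)
    if i == 4:
        break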

View File

@@ -66,7 +66,6 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
-from multidataset import MultiDataset
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@@ -376,13 +375,6 @@ def get_parser():
        help="Whether to use half precision training.",
    )

-    parser.add_argument(
-        "--use-multidataset",
-        type=str2bool,
-        default=False,
-        help="Whether to use multidataset to train.",
-    )
-
    add_model_arguments(parser)

    return parser
@@ -1042,10 +1034,6 @@ def run(rank, world_size, args):

    librispeech = LibriSpeechAsrDataModule(args)

-    if params.use_multidataset:
-        multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
-        train_cuts = multidataset.train_cuts()
-    else:
-        if params.mini_libri:
-            train_cuts = librispeech.train_clean_5_cuts()
-        elif params.full_libri:
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    elif params.full_libri:
@@ -1107,7 +1095,7 @@ def run(rank, world_size, args):
        valid_cuts += librispeech.dev_other_cuts()
    valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.use_multidataset and not params.print_diagnostics:
+    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,

View File

@@ -62,20 +62,20 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer2
-from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
-from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -84,40 +84,38 @@ from icefall.checkpoint import (
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
-from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
+    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
-    get_parameter_groups_with_lrs
)

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


-def get_adjusted_batch_count(
-        params: AttributeDict) -> float:
+def get_adjusted_batch_count(params: AttributeDict) -> float:
    # returns the number of batches we would have used so far if we had used the reference
    # duration. This is for purposes of set_batch_count().
-    return (params.batch_idx_train * (params.max_duration * params.world_size) /
-            params.ref_duration)
+    return (
+        params.batch_idx_train
+        * (params.max_duration * params.world_size)
+        / params.ref_duration
+    )


-def set_batch_count(
-        model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
-        if hasattr(module, 'name'):
+        if hasattr(module, "name"):
            module.name = name
@@ -154,35 +152,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        "--encoder-dim",
        type=str,
        default="192,256,384,512,384,256",
-        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--query-head-dim",
        type=str,
        default="32",
-        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--value-head-dim",
        type=str,
        default="12",
-        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--pos-head-dim",
        type=str,
        default="4",
-        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--pos-dim",
        type=int,
        default="48",
-        help="Positional-encoding embedding dimension"
+        help="Positional-encoding embedding dimension",
    )

    parser.add_argument(
@@ -190,7 +188,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        type=str,
        default="192,192,256,256,256,192",
        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
    )

    parser.add_argument(
@@ -230,7 +228,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        type=str,
        default="16,32,64,-1",
        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
-        " Must be just -1 if --causal=False"
+        " Must be just -1 if --causal=False",
    )

    parser.add_argument(
@@ -239,7 +237,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        default="64,128,256,-1",
        help="Maximum left-contexts for causal training, measured in frames which will "
        "be converted to a number of chunks. If splitting into chunks, "
-        "chunk left-context frames will be chosen randomly from this list; else not relevant."
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
    )
@@ -313,10 +311,7 @@ def get_parser():
    )

    parser.add_argument(
-        "--base-lr",
-        type=float,
-        default=0.045,
-        help="The base learning rate."
+        "--base-lr", type=float, default=0.045, help="The base learning rate."
    )

    parser.add_argument(
@@ -340,15 +335,14 @@ def get_parser():
        type=float,
        default=600,
        help="Reference batch duration for purposes of adjusting batch counts for setting various "
-        "schedules inside the model"
+        "schedules inside the model",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )

    parser.add_argument(
@@ -371,8 +365,7 @@ def get_parser():
        "--am-scale",
        type=float,
        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
    )

    parser.add_argument(
@@ -522,7 +515,7 @@ def get_params() -> AttributeDict:
def _to_int_tuple(s: str):
-    return tuple(map(int, s.split(',')))
+    return tuple(map(int, s.split(",")))


def get_encoder_embed(params: AttributeDict) -> nn.Module:
@@ -537,7 +530,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
    encoder_embed = Conv2dSubsampling(
        in_channels=params.feature_dim,
        out_channels=_to_int_tuple(params.encoder_dim)[0],
-        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    return encoder_embed
@@ -596,7 +589,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
-        encoder_dim=int(max(params.encoder_dim.split(','))),
+        encoder_dim=int(max(params.encoder_dim.split(","))),
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
@@ -745,11 +738,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
@@ -779,27 +768,24 @@ def compute_loss(
    # take down the scale on the simple loss from 1.0 at the start
    # to params.simple_loss scale by warm_step.
    simple_loss_scale = (
-        s if batch_idx_train >= warm_step
+        s
+        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
-        1.0 if batch_idx_train >= warm_step
+        1.0
+        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
-    loss = (
-        simple_loss_scale * simple_loss +
-        pruned_loss_scale * pruned_loss
-    )
+    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
@@ -895,7 +881,8 @@ def train_one_epoch(
    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
-        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+        save_checkpoint_impl(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
@@ -903,7 +890,8 @@ def train_one_epoch(
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
-            rank=0)
+            rank=0,
+        )

    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
@@ -988,7 +976,9 @@ def train_one_epoch(
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
-                raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
@@ -998,8 +988,8 @@ def train_one_epoch(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
@@ -1010,9 +1000,7 @@ def train_one_epoch(
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -1029,7 +1017,9 @@ def train_one_epoch(
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
@@ -1103,12 +1093,10 @@ def run(rank, world_size, args):
    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
-        get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True),
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr, # should have no effect
        clipping_scale=2.0,
    )
@@ -1129,7 +1117,7 @@ def run(rank, world_size, args):
    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
        ) # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1153,9 +1141,9 @@ def run(rank, world_size, args):
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
            return False

        # In pruned RNN-T, we require that T >= S
@@ -1206,8 +1194,7 @@ def run(rank, world_size, args):
            params=params,
        )

-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1328,7 +1315,9 @@ def scan_pessimistic_batches_for_oom(
            )
            display_and_save_batch(batch, params=params, sp=sp)
            raise
-        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )


def main():
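
The reformatted get_adjusted_batch_count above keeps the original formula; here is a small worked example of what it computes (the numbers are illustrative, not values taken from the recipe):

# batch_idx_train * (max_duration * world_size) / ref_duration
batch_idx_train = 10_000  # optimizer steps taken so far
max_duration = 1_200      # seconds of audio per batch, per GPU
world_size = 4            # number of GPUs
ref_duration = 600        # reference batch duration used by the internal schedules

adjusted = batch_idx_train * (max_duration * world_size) / ref_duration
print(adjusted)  # 80000.0, the batch count that is then handed to set_batch_count()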