Remove multidataset from librispeech/pruned_transducer_stateless7 (#1105)
* Add People's Speech to multidataset
* update
* remove multi from librispeech
This commit is contained in:
parent 7a604057f9
commit 82f34a2388
@ -1,117 +0,0 @@
#!/usr/bin/env bash

set -eou pipefail

nj=16
stage=-1
stop_stage=100

# Split data/${lang}set to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=1000

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/$release/$lang
#      This directory contains the following files downloaded from
#      https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
#
#      - clips
#      - dev.tsv
#      - invalidated.tsv
#      - other.tsv
#      - reported.tsv
#      - test.tsv
#      - train.tsv
#      - validated.tsv

dl_dir=$PWD/download
release=cv-corpus-13.0-2023-03-09
lang=en

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data/${lang}".
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
mkdir -p data/${lang}

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/$release,
  # you can create a symlink
  #
  #   ln -sfv /path/to/$release $dl_dir/$release
  #
  if [ ! -d $dl_dir/$release/$lang/clips ]; then
    lhotse download commonvoice --languages $lang --release $release $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare CommonVoice manifest"
  # We assume that you have downloaded the CommonVoice corpus
  # to $dl_dir/$release
  mkdir -p data/${lang}/manifests
  if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
    lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
    touch data/${lang}/manifests/.cv-${lang}.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Preprocess CommonVoice manifest"
  if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
    ./local/preprocess_commonvoice.py --language $lang
    touch data/${lang}/fbank/.preprocess_complete
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute fbank for dev and test subsets of CommonVoice"
  mkdir -p data/${lang}/fbank
  if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
    ./local/compute_fbank_commonvoice_dev_test.py --language $lang
    touch data/${lang}/fbank/.cv-${lang}_dev_test.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Split train subset into ${num_splits} pieces"
  split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits}
  if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
    lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
    touch $split_dir/.cv-${lang}_train_split.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute features for train subset of CommonVoice"
  if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
    ./local/compute_fbank_commonvoice_splits.py \
      --num-workers $nj \
      --batch-duration 600 \
      --start 0 \
      --num-splits $num_splits \
      --language $lang
    touch data/${lang}/fbank/.cv-${lang}_train.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Combine features for train"
  if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then
    pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
    lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
  fi
fi
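After Stage 6 completes, the combined manifest can be sanity-checked from Python. This is an editorial sketch, not part of the recipe; it assumes lang=en and the default output path above:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/en/fbank/cv-en_cuts_train.jsonl.gz")  # path assumes lang=en
cuts.describe()  # iterates the manifest and prints cut count and duration statistics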
@ -1,159 +0,0 @@
#!/usr/bin/env bash

set -eou pipefail

nj=15
stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/GigaSpeech
#      You can find audio, dict, GigaSpeech.json inside it.
#      You can apply for the download credentials by following
#      https://github.com/SpeechColab/GigaSpeech#download

# Number of hours for GigaSpeech subsets
# XL   10k hours
# L    2.5k hours
# M    1k hours
# S    250 hours
# XS   10 hours
# DEV  12 hours
# Test 40 hours

# Split XL subset to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=2000
# We use lazy split from lhotse.
# The XL subset (10k hours) contains 37956 cuts without speed perturbing.
# We want to split it into 2000 splits, so each split
# contains about 37956 / 2000 = 19 cuts. As a result, there will be 1998 splits.
chunk_size=19 # number of cuts in each split. The last split may contain fewer cuts.

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  [ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech

  # If you have pre-downloaded it to /path/to/GigaSpeech,
  # you can create a symlink
  #
  #   ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
  #
  if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
    # Check credentials.
    if [ ! -f $dl_dir/password ]; then
      echo -n "$0: Please apply for the download credentials by following "
      echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
      echo " and save it to $dl_dir/password."
      exit 1;
    fi
    PASSWORD=`cat $dl_dir/password 2>/dev/null`
    if [ -z "$PASSWORD" ]; then
      echo "$0: Error, $dl_dir/password is empty."
      exit 1;
    fi
    PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
    if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
      echo "$0: Error, invalid $dl_dir/password."
      exit 1;
    fi
    # Download XL, DEV and TEST sets by default.
    lhotse download gigaspeech \
      --subset XL \
      --subset L \
      --subset M \
      --subset S \
      --subset XS \
      --subset DEV \
      --subset TEST \
      --host tsinghua \
      $dl_dir/password $dl_dir/GigaSpeech
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)"
  # We assume that you have downloaded the GigaSpeech corpus
  # to $dl_dir/GigaSpeech
  if [ ! -f data/manifests/.gigaspeech.done ]; then
    mkdir -p data/manifests
    lhotse prepare gigaspeech \
      --subset XL \
      --subset L \
      --subset M \
      --subset S \
      --subset XS \
      --subset DEV \
      --subset TEST \
      -j $nj \
      $dl_dir/GigaSpeech data/manifests
    touch data/manifests/.gigaspeech.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Preprocess GigaSpeech manifest"
  if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then
    log "It may take 2 hours for this stage"
    ./local/preprocess_gigaspeech.py
    touch data/fbank/.gigaspeech_preprocess.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
  if [ ! -f data/fbank/.gigaspeech_dev_test.done ]; then
    ./local/compute_fbank_gigaspeech_dev_test.py
    touch data/fbank/.gigaspeech_dev_test.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Split XL subset into ${num_splits} pieces"
  split_dir=data/fbank/gigaspeech_XL_split_${num_splits}
  if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
    lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size
    touch $split_dir/.gigaspeech_XL_split.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute features for XL"
  # Note: The script supports --start and --stop options.
  # You can use several machines to compute the features in parallel.
  if [ ! -f data/fbank/.gigaspeech_XL.done ]; then
    ./local/compute_fbank_gigaspeech_splits.py \
      --num-workers $nj \
      --batch-duration 600 \
      --num-splits $num_splits
    touch data/fbank/.gigaspeech_XL.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Combine features for XL (may take 15 hours)"
  if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then
    pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz")
    lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz
  fi
fi
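The chunk_size arithmetic in the comments above can be checked with a few lines of Python; this is a sketch using the numbers quoted in the comments, not part of the recipe:

import math

num_cuts = 37956  # XL cuts without speed perturbation, per the comment above
chunk_size = 19   # cuts per piece passed to `lhotse split-lazy`
print(math.ceil(num_cuts / chunk_size))  # 1998 pieces; the last piece holds the remainder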
@ -1,330 +0,0 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=16
stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/LibriSpeech
#      You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
#      You can download them from https://www.openslr.org/12
#
#  - $dl_dir/lm
#      This directory contains the following files downloaded from
#      http://www.openslr.org/resources/11
#
#      - 3-gram.pruned.1e-7.arpa.gz
#      - 3-gram.pruned.1e-7.arpa
#      - 4-gram.arpa.gz
#      - 4-gram.arpa
#      - librispeech-vocab.txt
#      - librispeech-lexicon.txt
#      - librispeech-lm-norm.txt.gz
#
#  - $dl_dir/musan
#      This directory contains the following directories downloaded from
#      http://www.openslr.org/17/
#
#      - music
#      - noise
#      - speech

# Split all dataset to this number of pieces and mix each dataset pieces
# into multidataset pieces with shuffling.
num_splits=1998

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
  # 5000
  # 2000
  # 1000
  500
)

# multidataset list.
# LibriSpeech and musan are required.
# The others are optional.
multidataset=(
  "gigaspeech",
  "commonvoice",
)

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

log "Dataset: LibriSpeech and musan"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "Stage -1: Download LM"
  mkdir -p $dl_dir/lm
  if [ ! -e $dl_dir/lm/.done ]; then
    ./local/download_lm.py --out-dir=$dl_dir/lm
    touch $dl_dir/lm/.done
  fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/LibriSpeech,
  # you can create a symlink
  #
  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
  #
  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
    lhotse download librispeech --full $dl_dir
  fi

  # If you have pre-downloaded it to /path/to/musan,
  # you can create a symlink
  #
  #   ln -sfv /path/to/musan $dl_dir/
  #
  if [ ! -d $dl_dir/musan ]; then
    lhotse download musan $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LibriSpeech manifest"
  # We assume that you have downloaded the LibriSpeech corpus
  # to $dl_dir/LibriSpeech
  mkdir -p data/manifests
  if [ ! -e data/manifests/.librispeech.done ]; then
    lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
    touch data/manifests/.librispeech.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to data/musan
  mkdir -p data/manifests
  if [ ! -e data/manifests/.musan.done ]; then
    lhotse prepare musan $dl_dir/musan data/manifests
    touch data/manifests/.musan.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute fbank for librispeech"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.librispeech.done ]; then
    ./local/compute_fbank_librispeech.py --perturb-speed False
    touch data/fbank/.librispeech.done
  fi

  if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then
    cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
      <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
      <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
      shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
  fi

  if [ ! -e data/fbank/.librispeech-validated.done ]; then
    log "Validating data/fbank for LibriSpeech"
    parts=(
      train-clean-100
      train-clean-360
      train-other-500
      test-clean
      test-other
      dev-clean
      dev-other
    )
    for part in ${parts[@]}; do
      python3 ./local/validate_manifest.py \
        data/fbank/librispeech_cuts_${part}.jsonl.gz
    done
    touch data/fbank/.librispeech-validated.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.musan.done ]; then
    ./local/compute_fbank_musan.py
    touch data/fbank/.musan.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Prepare phone based lang"
  lang_dir=data/lang_phone
  mkdir -p $lang_dir

  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
    cat - $dl_dir/lm/librispeech-lexicon.txt |
    sort | uniq > $lang_dir/lexicon.txt

  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang.py --lang-dir $lang_dir
  fi

  if [ ! -f $lang_dir/L.fst ]; then
    log "Converting L.pt to L.fst"
    ./shared/convert-k2-to-openfst.py \
      --olabels aux_labels \
      $lang_dir/L.pt \
      $lang_dir/L.fst
  fi

  if [ ! -f $lang_dir/L_disambig.fst ]; then
    log "Converting L_disambig.pt to L_disambig.fst"
    ./shared/convert-k2-to-openfst.py \
      --olabels aux_labels \
      $lang_dir/L_disambig.pt \
      $lang_dir/L_disambig.fst
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare BPE based lang"

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}
    mkdir -p $lang_dir
    # We reuse words.txt from phone based lexicon
    # so that the two can share G.pt later.
    cp data/lang_phone/words.txt $lang_dir

    if [ ! -f $lang_dir/transcript_words.txt ]; then
      log "Generate data for BPE training"
      files=$(
        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
        find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
        find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
      )
      for f in ${files[@]}; do
        cat $f | cut -d " " -f 2-
      done > $lang_dir/transcript_words.txt
    fi

    if [ ! -f $lang_dir/bpe.model ]; then
      ./local/train_bpe_model.py \
        --lang-dir $lang_dir \
        --vocab-size $vocab_size \
        --transcript $lang_dir/transcript_words.txt
    fi

    if [ ! -f $lang_dir/L_disambig.pt ]; then
      ./local/prepare_lang_bpe.py --lang-dir $lang_dir

      log "Validating $lang_dir/lexicon.txt"
      ./local/validate_bpe_lexicon.py \
        --lexicon $lang_dir/lexicon.txt \
        --bpe-model $lang_dir/bpe.model
    fi

    if [ ! -f $lang_dir/L.fst ]; then
      log "Converting L.pt to L.fst"
      ./shared/convert-k2-to-openfst.py \
        --olabels aux_labels \
        $lang_dir/L.pt \
        $lang_dir/L.fst
    fi

    if [ ! -f $lang_dir/L_disambig.fst ]; then
      log "Converting L_disambig.pt to L_disambig.fst"
      ./shared/convert-k2-to-openfst.py \
        --olabels aux_labels \
        $lang_dir/L_disambig.pt \
        $lang_dir/L_disambig.fst
    fi
  done
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Prepare G"
  # We assume you have installed kaldilm, if not, please install
  # it using: pip install kaldilm

  mkdir -p data/lm
  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
    # It is used in building HLG
    python3 -m kaldilm \
      --read-symbol-table="data/lang_phone/words.txt" \
      --disambig-symbol='#0' \
      --max-order=3 \
      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
  fi

  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
    # It is used for LM rescoring
    python3 -m kaldilm \
      --read-symbol-table="data/lang_phone/words.txt" \
      --disambig-symbol='#0' \
      --max-order=4 \
      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
  fi
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Compile HLG"
  ./local/compile_hlg.py --lang-dir data/lang_phone

  # Note: If ./local/compile_hlg.py throws OOM,
  # please switch to the following command
  #
  # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}
    ./local/compile_hlg.py --lang-dir $lang_dir

    # Note: If ./local/compile_hlg.py throws OOM,
    # please switch to the following command
    #
    # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
  done
fi

# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Compile LG"
  ./local/compile_lg.py --lang-dir data/lang_phone

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}
    ./local/compile_lg.py --lang-dir $lang_dir
  done
fi

if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Prepare the other datasets"
  # GigaSpeech
  if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
    log "Dataset: GigaSpeech"
    ./prepare_giga_speech.sh --stop_stage 5
  fi

  # CommonVoice
  if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
    log "Dataset: CommonVoice"
    ./prepare_common_voice.sh
  fi
fi
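The shell pipeline in Stage 3 that builds librispeech_cuts_train-all-shuf.jsonl.gz (gunzip, shuf, gzip) has a rough Python equivalent using lhotse; this is a hedged sketch for illustration only, the recipe itself uses the shell version:

from lhotse import combine, load_manifest_lazy

parts = [
    "data/fbank/librispeech_cuts_train-clean-100.jsonl.gz",
    "data/fbank/librispeech_cuts_train-clean-360.jsonl.gz",
    "data/fbank/librispeech_cuts_train-other-500.jsonl.gz",
]
cuts = combine(load_manifest_lazy(p) for p in parts)
# Shuffling requires materializing the cuts in memory, unlike the streaming shell pipeline.
cuts.to_eager().shuffle().to_file("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")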
@ -1,77 +0,0 @@
# Copyright      2023  Xiaomi Corp.        (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
import logging
import re
from pathlib import Path

import lhotse
from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, manifest_dir: str, cv_manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:

            - librispeech_cuts_train-all-shuf.jsonl.gz
            - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz

          cv_manifest_dir:
            It is expected to contain the following files:

            - cv-en_cuts_train.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)
        self.cv_manifest_dir = Path(cv_manifest_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # LibriSpeech
        logging.info("Loading LibriSpeech in lazy mode")
        librispeech_cuts = load_manifest_lazy(
            self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
        )

        # GigaSpeech
        filenames = glob.glob(
            f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
        )

        pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
        idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
        idx_filenames = sorted(idx_filenames, key=lambda x: x[0])

        sorted_filenames = [f[1] for f in idx_filenames]

        logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")

        gigaspeech_cuts = lhotse.combine(
            lhotse.load_manifest_lazy(p) for p in sorted_filenames
        )

        # CommonVoice
        logging.info("Loading CommonVoice in lazy mode")
        commonvoice_cuts = load_manifest_lazy(
            self.cv_manifest_dir / "cv-en_cuts_train.jsonl.gz"
        )

        return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)
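For context, a minimal sketch of how this loader was driven from the training script; the directory names are illustrative and correspond to the layouts documented in the docstring above:

from multidataset import MultiDataset

multidataset = MultiDataset(
    manifest_dir="data/fbank",        # LibriSpeech cuts plus the GigaSpeech XL splits
    cv_manifest_dir="data/en/fbank",  # CommonVoice (lang=en) training cuts
)
train_cuts = multidataset.train_cuts()  # a CutSet that muxes the three corpora lazily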
@ -66,7 +66,6 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from multidataset import MultiDataset
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@ -376,13 +375,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
-
-    parser.add_argument(
-        "--use-multidataset",
-        type=str2bool,
-        default=False,
-        help="Whether to use multidataset to train.",
-    )

     add_model_arguments(parser)

     return parser
@ -1042,10 +1034,6 @@ def run(rank, world_size, args):

     librispeech = LibriSpeechAsrDataModule(args)

-    if params.use_multidataset:
-        multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
-        train_cuts = multidataset.train_cuts()
-    else:
-        if params.mini_libri:
-            train_cuts = librispeech.train_clean_5_cuts()
-        elif params.full_libri:
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    elif params.full_libri:
@ -1107,7 +1095,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.use_multidataset and not params.print_diagnostics:
+    if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@ -62,20 +62,20 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer2
-from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
-from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@ -84,40 +84,38 @@ from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
-    get_parameter_groups_with_lrs
 )

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


-def get_adjusted_batch_count(
-        params: AttributeDict) -> float:
+def get_adjusted_batch_count(params: AttributeDict) -> float:
     # returns the number of batches we would have used so far if we had used the reference
     # duration. This is for purposes of set_batch_count().
-    return (params.batch_idx_train * (params.max_duration * params.world_size) /
-            params.ref_duration)
+    return (
+        params.batch_idx_train
+        * (params.max_duration * params.world_size)
+        / params.ref_duration
+    )


-def set_batch_count(
-        model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
     for name, module in model.named_modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count
-        if hasattr(module, 'name'):
+        if hasattr(module, "name"):
             module.name = name
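To make the reformatted formula above concrete, a hypothetical evaluation (the values are made up purely for illustration):

# get_adjusted_batch_count scales the true batch index by how much audio each
# optimizer step actually covers, relative to ref_duration.
batch_idx_train = 1000
max_duration = 1200  # seconds of audio per batch on one GPU (hypothetical)
world_size = 4
ref_duration = 600
print(batch_idx_train * (max_duration * world_size) / ref_duration)  # 8000.0 "reference" batches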
@ -154,35 +152,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dim",
         type=str,
         default="192,256,384,512,384,256",
-        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--query-head-dim",
         type=str,
         default="32",
-        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--value-head-dim",
         type=str,
         default="12",
-        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--pos-head-dim",
         type=str,
         default="4",
-        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--pos-dim",
         type=int,
         default="48",
-        help="Positional-encoding embedding dimension"
+        help="Positional-encoding embedding dimension",
     )

     parser.add_argument(
@ -190,7 +188,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         type=str,
         default="192,192,256,256,256,192",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
     )

     parser.add_argument(
@ -230,7 +228,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         type=str,
         default="16,32,64,-1",
         help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
-        " Must be just -1 if --causal=False"
+        " Must be just -1 if --causal=False",
     )

     parser.add_argument(
@ -239,7 +237,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="64,128,256,-1",
         help="Maximum left-contexts for causal training, measured in frames which will "
         "be converted to a number of chunks. If splitting into chunks, "
-        "chunk left-context frames will be chosen randomly from this list; else not relevant."
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
     )

@ -313,10 +311,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--base-lr",
-        type=float,
-        default=0.045,
-        help="The base learning rate."
+        "--base-lr", type=float, default=0.045, help="The base learning rate."
     )

     parser.add_argument(
@ -340,15 +335,14 @@ def get_parser():
         type=float,
         default=600,
         help="Reference batch duration for purposes of adjusting batch counts for setting various "
-        "schedules inside the model"
+        "schedules inside the model",
     )

     parser.add_argument(
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
|
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(
|


 def _to_int_tuple(s: str):
-    return tuple(map(int, s.split(',')))
+    return tuple(map(int, s.split(",")))


 def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
     encoder_embed = Conv2dSubsampling(
         in_channels=params.feature_dim,
         out_channels=_to_int_tuple(params.encoder_dim)[0],
-        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
     )
     return encoder_embed

|
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(max(params.encoder_dim.split(','))),
+        encoder_dim=int(max(params.encoder_dim.split(","))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
|
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
|
     # take down the scale on the simple loss from 1.0 at the start
     # to params.simple_loss scale by warm_step.
     simple_loss_scale = (
-        s if batch_idx_train >= warm_step
+        s
+        if batch_idx_train >= warm_step
         else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
     )
     pruned_loss_scale = (
-        1.0 if batch_idx_train >= warm_step
+        1.0
+        if batch_idx_train >= warm_step
         else 0.1 + 0.9 * (batch_idx_train / warm_step)
     )

-    loss = (
-        simple_loss_scale * simple_loss +
-        pruned_loss_scale * pruned_loss
-    )
+    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
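The warm-up behaviour of the two loss scales above, sketched with assumed values of s and warm_step (not taken from the recipe's defaults):

s = 0.5           # simple-loss scale after warm-up, assumed
warm_step = 2000  # assumed

def scales(batch_idx_train):
    simple = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
    return simple, pruned

print(scales(0))     # (1.0, 0.1): rely mostly on the simple loss at the start
print(scales(1000))  # (0.75, 0.55)
print(scales(2000))  # (0.5, 1.0): fully warmed up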
@ -895,7 +881,8 @@ def train_one_epoch(
     saved_bad_model = False

     def save_bad_model(suffix: str = ""):
-        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+        save_checkpoint_impl(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
             model=model,
             model_avg=model_avg,
             params=params,
@ -903,7 +890,8 @@ def train_one_epoch(
             scheduler=scheduler,
             sampler=train_dl.sampler,
             scaler=scaler,
-            rank=0)
+            rank=0,
+        )

     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
@ -988,7 +976,9 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@ -998,8 +988,8 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )

             if tb_writer is not None:
|
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
|
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@ -1103,12 +1093,10 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = ScaledAdam(
-        get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True),
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
         lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
     )
@ -1129,7 +1117,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -1153,9 +1141,9 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False

     # In pruned RNN-T, we require that T >= S
@ -1206,8 +1194,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1328,7 +1315,9 @@ def scan_pessimistic_batches_for_oom(
             )
             display_and_save_batch(batch, params=params, sp=sp)
             raise
-        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )


 def main():