Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Update RNNLM training scripts (#720)

* Update RNNLM training scripts
* Fix a typo
* Fix CI

Commit 2bca7032af (parent 556c63fbb7)
New file (vendored): .github/workflows/run-ptb-rnn-lm.yml (67 lines)

@@ -0,0 +1,67 @@
name: run-ptb-rnn-lm-training

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

  schedule:
    # minute (0-59)
    # hour (0-23)
    # day of the month (1-31)
    # month (1-12)
    # day of the week (0-6)
    # nightly build at 15:50 UTC time every day
    - cron: "50 15 * * *"

jobs:
  run_ptb_rnn_lm_training:
    if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.8"]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf

      - name: Prepare data
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/ptb/LM
          ./prepare.sh

      - name: Run training
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/ptb/LM
          ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2

      - name: Upload pretrained models
        uses: actions/upload-artifact@v2
        if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
        with:
          name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb
          path: egs/ptb/LM/my-rnnlm-exp/
@@ -89,6 +89,10 @@ def main():
             bos_id=-1,
             eos_id=-1,
         )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
 
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
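The added else/return branch makes re-running the preparation idempotent: if the BPE model file already exists, training is skipped. A self-contained sketch of the same pattern (the directory layout, model prefix, and options here are illustrative, not the exact icefall script):

import shutil
from pathlib import Path

import sentencepiece as spm


def train_bpe(lang_dir: str, vocab_size: int, transcript: str) -> None:
    model_prefix = f"{lang_dir}/unigram_{vocab_size}"  # hypothetical layout
    model_file = Path(model_prefix + ".model")
    if not model_file.is_file():
        spm.SentencePieceTrainer.train(
            input=transcript,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            bos_id=-1,  # disable <s>/</s> pieces, as in the hunk above
            eos_id=-1,
        )
    else:
        # Model already trained; nothing to do.
        print(f"{model_file} exists - skipping")
        return

    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")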
@@ -22,9 +22,9 @@ dl_dir=$PWD/download
 # if the array contains xxx, yyy
 vocab_sizes=(
   500
-  1000
-  2000
-  5000
+  # 1000
+  # 2000
+  # 5000
 )
 
 # All files generated by this script are saved in "data".
@@ -42,11 +42,14 @@ log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download data"
+
+  # Caution: The downloaded data has already been normalized for LM training.
+
   if [ ! -f $dl_dir/.complete ]; then
-    url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
+    url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data
+    wget --directory-prefix $dl_dir $url/ptb.train.txt
+    wget --directory-prefix $dl_dir $url/ptb.valid.txt
+    wget --directory-prefix $dl_dir $url/ptb.test.txt
     touch $dl_dir/.complete
   fi
 fi
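The .complete marker file is what makes stage -1 idempotent: on a re-run the downloads are skipped entirely. The same pattern expressed in Python, purely as a sketch (not part of the recipe; the URL is the one from prepare.sh):

import urllib.request
from pathlib import Path

dl_dir = Path("download")
dl_dir.mkdir(exist_ok=True)
marker = dl_dir / ".complete"

if not marker.is_file():
    url = ("https://raw.githubusercontent.com/townie/"
           "PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data")
    for name in ("ptb.train.txt", "ptb.valid.txt", "ptb.test.txt"):
        urllib.request.urlretrieve(f"{url}/{name}", dl_dir / name)
    marker.touch()  # record completion so re-runs skip the downloads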
@@ -54,11 +57,15 @@ fi
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train BPE model"
+
+  # Caution: You have to use the same bpe model for training your acoustic model
+
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
-    mkdir -p $out_dir
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
     ./local/train_bpe_model.py \
-      --out-dir $out_dir \
+      --lang-dir $lang_dir \
       --vocab-size $vocab_size \
       --transcript $dl_dir/ptb.train.txt
   done
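The caution holds because BPE token IDs are only comparable when the acoustic model and the LM encode text with the same bpe.model. A quick sanity check, assuming the model trained above exists at data/lang_bpe_500/bpe.model:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

# The same text must map to the same IDs on both sides.
print(sp.encode("the quick brown fox", out_type=int))  # token IDs
print(sp.encode("the quick brown fox", out_type=str))  # BPE pieces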
@@ -69,20 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # Note: ptb.train.txt has already been normalized
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.train.txt \
       --lm-archive $out_dir/lm_data.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.valid.txt \
       --lm-archive $out_dir/lm_data-valid.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.test.txt \
       --lm-archive $out_dir/lm_data-test.pt
   done
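Conceptually, ./local/prepare_lm_training_data.py encodes each transcript line into BPE token IDs and serializes the result as a .pt archive. A rough stand-in for what the stage produces (the real script uses its own archive format; file names here are illustrative):

import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

with open("download/ptb.train.txt") as f:
    sentences = [sp.encode(line.strip(), out_type=int) for line in f]

# Illustrative file name; the recipe writes lm_data.pt in its own format.
torch.save(sentences, "data/lm_training_bpe_500/lm_data_sketch.pt")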
@@ -98,7 +106,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # in a sentence.
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/sort_lm_training_data.py \
       --in-lm-data $out_dir/lm_data.pt \
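Sorting by sentence length lets the dataloader batch sentences of similar length and waste less compute on padding. The effect of ./local/sort_lm_training_data.py, sketched on the toy archive format from the previous note (the real script operates on its own format):

import torch

sentences = torch.load("data/lm_training_bpe_500/lm_data_sketch.pt")
sentences.sort(key=len, reverse=True)  # longest sentences first
torch.save(sentences, "data/lm_training_bpe_500/sorted_lm_data_sketch.pt")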
New symbolic link: egs/ptb/LM/rnn_lm (1 line)

@@ -0,0 +1 @@
../../../icefall/rnn_lm
New executable file: egs/ptb/LM/train-rnn-lm.sh (67 lines)

@@ -0,0 +1,67 @@
#!/usr/bin/env bash

# Please run ./prepare.sh first

stage=-1
stop_stage=100

# Number of GPUs to use for training
world_size=1

# Number of epochs to train
num_epochs=20

# Use this epoch for computing ppl
use_epoch=19

# number of models to average for computing ppl
use_avg=2

exp_dir=./my-rnnlm-exp

. shared/parse_options.sh || exit 1

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Training RNN LM"

  ./rnn_lm/train.py \
    --exp-dir $exp_dir \
    --start-epoch 0 \
    --num-epochs $num_epochs \
    --world-size $world_size \
    --use-fp16 0 \
    --vocab-size 500 \
    \
    --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \
    --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \
    \
    --embedding-dim 800 \
    --hidden-dim 200 \
    --num-layers 2 \
    --tie-weights false \
    --batch-size 50
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Computing perplexity"

  ./rnn_lm/compute_perplexity.py \
    --exp-dir $exp_dir \
    --epoch $use_epoch \
    --avg $use_avg \
    --vocab-size 500 \
    \
    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \
    \
    --embedding-dim 800 \
    --hidden-dim 200 \
    --num-layers 2 \
    --tie-weights false \
    --batch-size 50
fi
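The --use-epoch and --use-avg knobs control which checkpoints feed the perplexity stage: the parameters of the last use_avg checkpoints ending at use_epoch are averaged. icefall ships its own averaging helper; a minimal sketch of the idea, assuming the usual icefall checkpoint layout (state dict under the "model" key) and hypothetical file paths:

import torch


def average_checkpoints(filenames):
    # Accumulate parameters, then divide the floating-point ones.
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= len(filenames)
    return avg


# --use-epoch 4 --use-avg 2 would average epochs 3 and 4 (hypothetical paths):
avg_state = average_checkpoints(
    ["my-rnnlm-exp/epoch-3.pt", "my-rnnlm-exp/epoch-4.pt"]
)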
@@ -20,7 +20,7 @@ Usage:
 ./rnn_lm/compute_perplexity.py \
   --epoch 4 \
   --avg 2 \
-  --lm-data ./data/bpe_500/sorted_lm_data-test.pt
+  --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt
 
 """
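For reference, the perplexity this script reports is exp of the average per-token negative log-likelihood over the test data. Toy arithmetic with made-up per-token log-probabilities:

import math

import torch

log_probs = torch.tensor([-2.1, -0.3, -4.0, -1.2])  # toy per-token log-probs
nll = -log_probs.sum().item()                        # total negative log-likelihood
ppl = math.exp(nll / log_probs.numel())
print(f"ppl: {ppl:.2f}")  # approx. 6.69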
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
|
|||||||
batch_size=params.batch_size,
|
batch_size=params.batch_size,
|
||||||
)
|
)
|
||||||
if is_distributed:
|
if is_distributed:
|
||||||
sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
|
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
|
||||||
else:
|
else:
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
|
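With drop_last=False, DistributedSampler duplicates a few samples so every rank gets the same count; drop_last=True instead drops the uneven tail, so each rank sees the same number of distinct samples with no repeats. A standalone sketch of the setting (num_replicas/rank normally come from the torch.distributed process group; they are fixed here only to make it runnable without DDP):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(103))

sampler = DistributedSampler(
    dataset, num_replicas=4, rank=0, shuffle=True, drop_last=True
)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
print(len(sampler))  # 25: the 3-sample tail of 103 is dropped, not duplicated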
@@ -24,7 +24,7 @@ Usage:
   --use-fp16 0 \
   --embedding-dim 800 \
   --hidden-dim 200 \
-  --num-layers 2\
+  --num-layers 2 \
   --batch-size 400
 
 """
@@ -83,7 +83,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -110,14 +110,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=50,
+        default=400,
     )
 
     parser.add_argument(
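type=str2bool matters because argparse's type=bool would turn any non-empty string, including "false", into True. icefall provides str2bool in icefall/utils.py; a sketch of how such a helper typically works (this is an illustration, not icefall's exact implementation):

import argparse


def str2bool(v: str) -> bool:
    # Accept the usual command-line spellings of true/false.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Unsupported value: {v}")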
@@ -165,7 +165,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
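Weight tying shares the input embedding matrix with the output projection, saving parameters and often helping perplexity, but it requires the embedding dimension to equal the dimension feeding the output layer; note that train-rnn-lm.sh above passes --tie-weights false precisely because it uses --embedding-dim 800 with --hidden-dim 200. A minimal sketch of the technique (not icefall's actual model class):

import torch.nn as nn


class TinyRnnLm(nn.Module):
    def __init__(self, vocab_size: int, dim: int, tie_weights: bool = True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.rnn = nn.LSTM(dim, dim, num_layers=2, batch_first=True)
        self.output = nn.Linear(dim, vocab_size)
        if tie_weights:
            # Both weights have shape (vocab_size, dim), so a single tensor
            # can serve as embedding and output projection.
            self.output.weight = self.embedding.weight

    def forward(self, x):
        y, _ = self.rnn(self.embedding(x))
        return self.output(y)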