[Ready to be merged] Add RNN-LM to Conformer-CTC decoding (#439)

This commit is contained in:
ezerhouni 2022-06-23 13:37:03 +02:00 committed by GitHub
parent dc89b61b80
commit 0475d75d15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 2659 additions and 42 deletions

View File

@ -1299,17 +1299,18 @@ You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3x
#### 2021-11-09
The best WER, as of 2021-11-09, for the librispeech test dataset is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):
The best WER, as of 2022-06-20, for the librispeech test dataset is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring + rnn lm rescoring):
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42 | 5.73 |
| WER | 2.32 | 5.39 |
Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
| ngram_lm_scale | attention_scale |
|----------------|-----------------|
| 2.0 | 2.0 |
| ngram_lm_scale | attention_scale | rnn_lm_scale |
|----------------|-----------------|--------------|
| 0.3 | 2.1 | 2.2 |
To reproduce the above result, use the following commands for training:
@ -1330,11 +1331,27 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 0 \
--num-epochs 90
# Note: It trains for 90 epochs, but the best WER is at epoch-77.pt
# Train the RNN-LM
cd icefall
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./rnn_lm/train.py \
--exp-dir rnn_lm/exp_2048_3_tied \
--start-epoch 0 \
--world-size 4 \
--num-epochs 30 \
--use-fp16 1 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--batch-size 500 \
--tie-weights true
```
and the following command for decoding
```
rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
./conformer_ctc/decode.py \
--exp-dir conformer_ctc/exp_500_att0.8 \
--lang-dir data/lang_bpe_500 \
@ -1344,13 +1361,23 @@ and the following command for decoding
--num-paths 1000 \
--epoch 77 \
--avg 55 \
--method attention-decoder \
--nbest-scale 0.5
--nbest-scale 0.5 \
--rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied \
--rnn-lm-epoch 29 \
--rnn-lm-avg 3 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights true \
--method rnn-lm
```
You can find the pre-trained model by visiting
You can find the Conformer-CTC pre-trained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>
and the RNN-LM pre-trained model:
<https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
The tensorboard log for training is available at
<https://tensorboard.dev/experiment/hZDWrZfaSqOMqtW0NEfXKg/#scalars>

View File

@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -38,15 +38,19 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -93,7 +97,9 @@ def get_parser():
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
@ -105,7 +111,7 @@ def get_parser():
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
@ -116,7 +122,7 @@ def get_parser():
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
@ -139,11 +145,67 @@ def get_parser():
"--lm-dir",
type=str,
default="data/lm",
help="""The LM dir.
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
return parser
@ -173,6 +235,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -205,6 +268,8 @@ def decode_one_batch(
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@ -330,6 +395,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@ -357,8 +423,6 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@ -370,6 +434,26 @@ def decode_one_batch(
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -388,6 +472,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -405,6 +490,8 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@ -442,6 +529,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
@ -490,7 +578,7 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
@ -566,6 +654,10 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
@ -590,6 +682,7 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@ -621,7 +714,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@ -648,20 +745,40 @@ def main():
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
@ -678,6 +795,7 @@ def main():
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,

View File

@ -23,6 +23,7 @@ This file downloads the following LibriSpeech LM files:
- 4-gram.arpa.gz
- librispeech-vocab.txt
- librispeech-lexicon.txt
- librispeech-lm-norm.txt.gz
from http://www.openslr.org/resources/11
and save them in the user provided directory.
@ -61,6 +62,7 @@ def main(out_dir: str):
"4-gram.arpa.gz",
"librispeech-vocab.txt",
"librispeech-lexicon.txt",
"librispeech-lm-norm.txt.gz",
)
for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):

View File

@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
# Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a `bpe.model` and a text file such as
./download/lm/librispeech-lm-norm.txt
and outputs the LM training data to a supplied directory such
as data/lm_training_bpe_500. The format is as follows:
It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
representation of a dict with the following format:
'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
containing the BPE representations of each word, indexed by
integer word ID. (These integer word IDS are present in
'lm_data'). The sentencepiece object can be used to turn the
words and BPE units into string form.
'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
torch.int32 containing all the sentences, as word-ids (we don't
output the string form of this directly but it can be worked out
together with 'words' and the bpe.model).
'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
number of BPE tokens of each sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import sentencepiece as spm
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="Input BPE model, e.g. data/bpe_500/bpe.model",
)
parser.add_argument(
"--lm-data",
type=str,
help="""Input LM training data as text, e.g.
download/pb.train.txt""",
)
parser.add_argument(
"--lm-archive",
type=str,
help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
look at the source of this script to see the format.""",
)
return parser.parse_args()
def main():
args = get_args()
if Path(args.lm_archive).exists():
logging.warning(f"{args.lm_archive} exists - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
# word2index is a dictionary from words to integer ids. No need to reserve
# space for epsilon, etc.; the words are just used as a convenient way to
# compress the sequences of BPE pieces.
word2index = dict()
word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces.
sentences = [] # Will be a list-of-list-of-int, representing word-ids.
if "librispeech-lm-norm" in args.lm_data:
num_lines_in_total = 40418261.0
step = 5000000
elif "valid" in args.lm_data:
num_lines_in_total = 5567.0
step = 3000
elif "test" in args.lm_data:
num_lines_in_total = 5559.0
step = 3000
else:
num_lines_in_total = None
step = None
processed = 0
with open(args.lm_data) as f:
while True:
line = f.readline()
if line == "":
break
if step and processed % step == 0:
logging.info(
f"Processed number of lines: {processed} "
f"({processed/num_lines_in_total*100: .3f}%)"
)
processed += 1
line_words = line.split()
for w in line_words:
if w not in word2index:
w_bpe = sp.encode(w)
word2index[w] = len(word2bpe)
word2bpe.append(w_bpe)
sentences.append([word2index[w] for w in line_words])
logging.info("Constructing ragged tensors")
words = k2.ragged.RaggedTensor(word2bpe)
sentences = k2.ragged.RaggedTensor(sentences)
output = dict(words=words, sentences=sentences)
num_sentences = sentences.dim0
logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
if step and i % step == 0:
logging.info(
f"Processed number of lines: {i} "
f"({i/num_sentences*100: .3f}%)"
)
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
output["sentence_lengths"] = torch.tensor(
sentence_lengths, dtype=torch.int32
)
torch.save(output, args.lm_archive)
logging.info(f"Saved to {args.lm_archive}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../ptb/LM/local/sort_lm_training_data.py

View File

@ -38,7 +38,6 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory.
""",
)

View File

@ -24,6 +24,7 @@ stop_stage=100
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
@ -40,9 +41,9 @@ dl_dir=$PWD/download
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
5000
2000
1000
500
)
@ -278,3 +279,99 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Generate LM training data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
lang_dir=data/lang_bpe_${vocab_size}
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $dl_dir/lm/librispeech-lm-norm.txt \
--lm-archive $out_dir/lm_data.pt
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Generate LM validation data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/valid.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/valid.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/valid.txt \
--lm-archive $out_dir/lm_data-valid.pt
done
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Generate LM test data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/test.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/test.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/test.txt \
--lm-archive $out_dir/lm_data-test.pt
done
fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi

18
egs/ptb/LM/README.md Normal file
View File

@ -0,0 +1,18 @@
## Description
(Note: the experiments here are only about language modeling)
ptb is short for Penn Treebank.
About the Penn Treebank corpus:
- This corpus is free for research purposes
- ptb.train.txt: train set
- ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training)
- ptb.test.txt: test set for reporting perplexity
You can download the dataset from one of the following URLs:
- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage
- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
- https://deepai.org/dataset/penn-treebank

View File

@ -0,0 +1 @@
../../librispeech/ASR/local/prepare_lm_training_data.py

View File

@ -0,0 +1,143 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input the filename of LM training data
generated by ./local/prepare_lm_training_data.py and sorts
it by sentence length.
Sentence length equals to the number of BPE tokens in a sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import numpy as np
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--in-lm-data",
type=str,
help="Input LM training data, e.g., data/bpe_500/lm_data.pt",
)
parser.add_argument(
"--out-lm-data",
type=str,
help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt",
)
parser.add_argument(
"--out-statistics",
type=str,
help="Statistics about LM training data., data/bpe_500/statistics.txt",
)
return parser.parse_args()
def main():
args = get_args()
in_lm_data = Path(args.in_lm_data)
out_lm_data = Path(args.out_lm_data)
assert in_lm_data.is_file(), f"{in_lm_data}"
if out_lm_data.is_file():
logging.warning(f"{out_lm_data} exists - skipping")
return
data = torch.load(in_lm_data)
words2bpe = data["words"]
sentences = data["sentences"]
sentence_lengths = data["sentence_lengths"]
num_sentences = sentences.dim0
assert num_sentences == sentence_lengths.numel(), (
num_sentences,
sentence_lengths.numel(),
)
indices = torch.argsort(sentence_lengths, descending=True)
sorted_sentences = sentences[indices.to(torch.int32)]
sorted_sentence_lengths = sentence_lengths[indices]
# Check that sentences are ordered by length
assert num_sentences == sorted_sentences.dim0, (
num_sentences,
sorted_sentences.dim0,
)
cur = None
for i in range(num_sentences):
word_ids = sorted_sentences[i]
token_ids = words2bpe[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
if cur is not None:
assert cur >= token_ids.numel(), (cur, token_ids.numel())
cur = token_ids.numel()
assert cur == sorted_sentence_lengths[i]
data["sentences"] = sorted_sentences
data["sentence_lengths"] = sorted_sentence_lengths
torch.save(data, args.out_lm_data)
logging.info(f"Saved to {args.out_lm_data}")
statistics = Path(args.out_statistics)
# Write statistics
num_words = sorted_sentences.numel()
num_tokens = sentence_lengths.sum().item()
max_sentence_length = sentence_lengths[indices[0]]
min_sentence_length = sentence_lengths[indices[-1]]
step = 10
hist, bins = np.histogram(
sentence_lengths.numpy(),
bins=np.arange(1, max_sentence_length + step, step),
)
histogram = np.stack((bins[:-1], hist)).transpose()
with open(statistics, "w") as f:
f.write(f"num_sentences: {num_sentences}\n")
f.write(f"num_words: {num_words}\n")
f.write(f"num_tokens: {num_tokens}\n")
f.write(f"max_sentence_length: {max_sentence_length}\n")
f.write(f"min_sentence_length: {min_sentence_length}\n")
f.write("histogram:\n")
f.write(" bin count percent\n")
for row in histogram:
f.write(
f"{int(row[0]):>5} {int(row[1]):>5} "
f"{100.*row[1]/num_sentences:.3f}%\n"
)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import sentencepiece as spm
import torch
def main():
lm_training_data = Path("./data/bpe_500/lm_data.pt")
bpe_model = Path("./data/bpe_500/bpe.model")
if not lm_training_data.exists():
logging.warning(f"{lm_training_data} does not exist - skipping")
return
if not bpe_model.exists():
logging.warning(f"{bpe_model} does not exist - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(str(bpe_model))
data = torch.load(lm_training_data)
words2bpe = data["words"]
sentences = data["sentences"]
ss = []
unk = sp.decode(sp.unk_id()).strip()
for i in range(10):
s = sp.decode(words2bpe[sentences[i]].values.tolist())
s = s.replace(unk, "<unk>")
ss.append(s)
for s in ss:
print(s)
# You can compare the output with the first 10 lines of ptb.train.txt
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../librispeech/ASR/local/train_bpe_model.py

115
egs/ptb/LM/prepare.sh Executable file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env bash
set -eou pipefail
nj=15
stage=-1
stop_stage=100
dl_dir=$PWD/download
# The following files will be downloaded to $dl_dir
# - ptb.train.txt
# - ptb.valid.txt
# - ptb.test.txt
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/bpe_xxx, data/bpe_yyy
# if the array contains xxx, yyy
vocab_sizes=(
500
1000
2000
5000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
mkdir -p $dl_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download data"
if [ ! -f $dl_dir/.complete ]; then
url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
touch $dl_dir/.complete
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train BPE model"
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/train_bpe_model.py \
--out-dir $out_dir \
--vocab-size $vocab_size \
--transcript $dl_dir/ptb.train.txt
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Generate LM training data"
# Note: ptb.train.txt has already been normalized
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $out_dir/bpe.model \
--lm-data $dl_dir/ptb.train.txt \
--lm-archive $out_dir/lm_data.pt
./local/prepare_lm_training_data.py \
--bpe-model $out_dir/bpe.model \
--lm-data $dl_dir/ptb.valid.txt \
--lm-archive $out_dir/lm_data-valid.pt
./local/prepare_lm_training_data.py \
--bpe-model $out_dir/bpe.model \
--lm-data $dl_dir/ptb.test.txt \
--lm-archive $out_dir/lm_data-test.pt
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Sort LM training data"
# Sort LM training data generated in stage 1
# by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi

1
egs/ptb/LM/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -20,7 +20,34 @@ from typing import Dict, List, Optional, Union
import k2
import torch
from icefall.utils import get_texts
from icefall.utils import add_eos, add_sos, get_texts
DEFAULT_LM_SCALE = [
0.01,
0.05,
0.08,
0.1,
0.3,
0.5,
0.6,
0.7,
0.9,
1.0,
1.1,
1.2,
1.3,
1.5,
1.7,
1.9,
2.0,
2.1,
2.2,
2.3,
2.5,
3.0,
4.0,
5.0,
]
def _intersect_device(
@ -952,3 +979,161 @@ def rescore_with_attention_decoder(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path
return ans
def rescore_with_rnn_lm(
lattice: k2.Fsa,
num_paths: int,
rnn_lm_model: torch.nn.Module,
model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
blank_id: int,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
rnn_lm_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
rnn_lm_scale:
Optional. It specifies the scale for RNN LM scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of memory is (T, N, C), so we use axis=1 here
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
token_ids = tokens.tolist()
if len(token_ids) == 0:
print("Warning: rescore_with_attention_decoder(): empty token-ids")
return None
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
# Now for RNN LM
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64)
y_tokens = y_tokens.to(torch.int64)
sentence_lengths = sentence_lengths.to(torch.int64)
rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
assert rnn_lm_nll.ndim == 2
assert rnn_lm_nll.shape[0] == len(token_ids)
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
ngram_lm_scale_list = DEFAULT_LM_SCALE
attention_scale_list = DEFAULT_LM_SCALE
rnn_lm_scale_list = DEFAULT_LM_SCALE
if ngram_lm_scale:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale:
attention_scale_list = [attention_scale]
if rnn_lm_scale:
rnn_lm_scale_list = [rnn_lm_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
for r_scale in rnn_lm_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
+ r_scale * rnn_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa
ans[key] = best_path
return ans

View File

@ -21,14 +21,46 @@ import torch
from torch import distributed as dist
def setup_dist(rank, world_size, master_port=None):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = (
"12354" if master_port is None else str(master_port)
)
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
"""
rank and world_size are used only if use_ddp_launch is False.
"""
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = (
"12354" if master_port is None else str(master_port)
)
if use_ddp_launch is False:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
else:
dist.init_process_group("nccl")
def cleanup_dist():
dist.destroy_process_group()
def get_world_size():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank():
if "RANK" in os.environ:
return int(os.environ["RANK"])
elif dist.is_available() and dist.is_initialized():
return dist.rank()
else:
return 1
def get_local_rank():
return int(os.environ.get("LOCAL_RANK", 0))

View File

@ -0,0 +1,237 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/compute_perplexity.py \
--epoch 4 \
--avg 2 \
--lm-data ./data/bpe_500/sorted_lm_data-test.pt
"""
import argparse
import logging
import math
from pathlib import Path
import torch
from dataset import get_dataloader
from model import RnnLmModel
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=49,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="The experiment dir",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="SOS ID",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="EOS ID",
)
parser.add_argument(
"--blank-id",
type=int,
default=0,
help="Blank ID",
)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = AttributeDict(vars(args))
setup_logger(f"{params.exp_dir}/log-ppl/")
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

218
icefall/rnn_lm/dataset.py Normal file
View File

@ -0,0 +1,218 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import k2
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from icefall.utils import AttributeDict, add_eos, add_sos
class LmDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: k2.RaggedTensor,
words: k2.RaggedTensor,
sentence_lengths: torch.Tensor,
max_sent_len: int,
batch_size: int,
):
"""
Args:
sentences:
A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
words:
A ragged tensor of dtype torch.int32 with 2 axes [word][token].
sentence_lengths:
A 1-D tensor of dtype torch.int32 containing number of tokens
of each sentence.
max_sent_len:
Maximum sentence length. It is used to change the batch size
dynamically. In general, we try to keep the product of
"max_sent_len in a batch" and "num_of_sent in a batch" being
a constant.
batch_size:
The expected batch size. It is changed dynamically according
to the "max_sent_len".
See `../local/prepare_lm_training_data.py` for how `sentences` and
`words` are generated. We assume that `sentences` are sorted by length.
See `../local/sort_lm_training_data.py`.
"""
super().__init__()
self.sentences = sentences
self.words = words
sentence_lengths = sentence_lengths.tolist()
assert batch_size > 0, batch_size
assert max_sent_len > 1, max_sent_len
batch_indexes = []
num_sentences = sentences.dim0
cur = 0
while cur < num_sentences:
sz = sentence_lengths[cur] // max_sent_len + 1
# Assume the current sentence has 3 * max_sent_len tokens,
# in the worst case, the subsequent sentences also have
# this number of tokens, we should reduce the batch size
# so that this batch will not contain too many tokens
actual_batch_size = batch_size // sz + 1
actual_batch_size = min(actual_batch_size, batch_size)
end = cur + actual_batch_size
end = min(end, num_sentences)
this_batch_indexes = torch.arange(cur, end).tolist()
batch_indexes.append(this_batch_indexes)
cur = end
assert batch_indexes[-1][-1] == num_sentences - 1
self.batch_indexes = k2.RaggedTensor(batch_indexes)
def __len__(self) -> int:
"""Return number of batches in this dataset"""
return self.batch_indexes.dim0
def __getitem__(self, i: int) -> k2.RaggedTensor:
"""Get the i'th batch in this dataset
Return a ragged tensor with 2 axes [sentence][token].
"""
assert 0 <= i < len(self), i
# indexes is a 1-D tensor containing sentence indexes
indexes = self.batch_indexes[i]
# sentence_words is a ragged tensor with 2 axes
# [sentence][word]
sentence_words = self.sentences[indexes]
# in case indexes contains only 1 entry, the returned
# sentence_words is a 1-D tensor, we have to convert
# it to a ragged tensor
if isinstance(sentence_words, torch.Tensor):
sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
# sentence_word_tokens is a ragged tensor with 3 axes
# [sentence][word][token]
sentence_word_tokens = self.words.index(sentence_words)
assert sentence_word_tokens.num_axes == 3
sentence_tokens = sentence_word_tokens.remove_axis(1)
return sentence_tokens
class LmDatasetCollate:
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
"""
Args:
sos_id:
Token ID of the SOS symbol.
eos_id:
Token ID of the EOS symbol.
blank_id:
Token ID of the blank symbol.
"""
self.sos_id = sos_id
self.eos_id = eos_id
self.blank_id = blank_id
def __call__(
self, batch: List[k2.RaggedTensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return a tuple containing 3 tensors:
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence starting with `self.sos_id`. It is padded to
the max sentence length with `self.blank_id`.
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence ending with `self.eos_id` before padding.
Then it is padded to the max sentence length with
`self.blank_id`.
- lengths, a 2-D tensor of dtype torch.int32, containing the number of
tokens of each sentence before padding.
"""
# The batching stuff has already been done in LmDataset
assert len(batch) == 1
sentence_tokens = batch[0]
row_splits = sentence_tokens.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
x = sentence_tokens_with_sos.pad(
mode="constant", padding_value=self.blank_id
)
y = sentence_tokens_with_eos.pad(
mode="constant", padding_value=self.blank_id
)
sentence_token_lengths += 1 # plus 1 since we added a SOS
return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
def get_dataloader(
filename: str,
is_distributed: bool,
params: AttributeDict,
) -> torch.utils.data.DataLoader:
"""Get dataloader for LM training.
Args:
filename:
Path to the file containing LM data. The file is assumed to
be generated by `../local/sort_lm_training_data.py`.
is_distributed:
True if using DDP training. False otherwise.
params:
Set `get_params()` from `rnn_lm/train.py`
Returns:
Return a dataloader containing the LM data.
"""
lm_data = torch.load(filename)
words = lm_data["words"]
sentences = lm_data["sentences"]
sentence_lengths = lm_data["sentence_lengths"]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=params.max_sent_len,
batch_size=params.batch_size,
)
if is_distributed:
sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
else:
sampler = None
collate_fn = LmDatasetCollate(
sos_id=params.sos_id,
eos_id=params.eos_id,
blank_id=params.blank_id,
)
dataloader = DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=sampler is None,
)
return dataloader

167
icefall/rnn_lm/export.py Normal file
View File

@ -0,0 +1,167 @@
#!/usr/bin/env python3
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from model import RnnLmModel
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict({})
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

120
icefall/rnn_lm/model.py Normal file
View File

@ -0,0 +1,120 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
class RnnLmModel(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
hidden_dim: int,
num_layers: int,
tie_weights: bool = False,
):
"""
Args:
vocab_size:
Vocabulary size of BPE model.
embedding_dim:
Input embedding dimension.
hidden_dim:
Hidden dimension of RNN layers.
num_layers:
Number of RNN layers.
tie_weights:
True to share the weights between the input embedding layer and the
last output linear layer. See https://arxiv.org/abs/1608.05859
and https://arxiv.org/abs/1611.01462
"""
super().__init__()
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.rnn = torch.nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
)
self.output_linear = torch.nn.Linear(
in_features=hidden_dim, out_features=vocab_size
)
self.vocab_size = vocab_size
if tie_weights:
logging.info("Tying weights")
assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor with shape (N, L). Each row
contains token IDs for a sentence and starts with the SOS token.
y:
A shifted version of `x` and with EOS appended.
lengths:
A 1-D tensor of shape (N,). It contains the sentence lengths
before padding.
Returns:
Return a 2-D tensor of shape (N, L) containing negative log-likelihood
loss values. Note: Loss values for padding positions are set to 0.
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert lengths.ndim == 1, lengths.ndim
assert x.shape == y.shape, (x.shape, y.shape)
batch_size = x.size(0)
assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, _ = self.rnn(embedding)
logits = self.output_linear(rnn_out)
# Note: No need to use `log_softmax()` here
# since F.cross_entropy() expects unnormalized probabilities
# nll_loss is of shape (N*L,)
# nll -> negative log-likelihood
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
# Set loss values for padding positions to 0
mask = make_pad_mask(lengths).reshape(-1)
nll_loss.masked_fill_(mask, 0)
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss

71
icefall/rnn_lm/test_dataset.py Executable file
View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate
def main():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=collate_fn
)
for i in dataloader:
print(i)
# I've checked the output manually; the output is as expected.
if __name__ == "__main__":
main()

View File

@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
import torch.multiprocessing as mp
from rnn_lm.dataset import LmDataset, LmDatasetCollate
from torch import distributed as dist
def generate_data():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
return sentences, words, sentence_lengths
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12352"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
sentences, words, sentence_lengths = generate_data()
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=False,
)
for i in dataloader:
print(f"rank: {rank}", i)
dist.destroy_process_group()
def main():
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

69
icefall/rnn_lm/test_model.py Executable file
View File

@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from rnn_lm.model import RnnLmModel
def test_rnn_lm_model():
vocab_size = 4
model = RnnLmModel(
vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2
)
x = torch.tensor(
[
[1, 3, 2, 2],
[1, 2, 2, 0],
[1, 2, 0, 0],
]
)
y = torch.tensor(
[
[3, 2, 2, 1],
[2, 2, 1, 0],
[2, 1, 0, 0],
]
)
lengths = torch.tensor([4, 3, 2])
nll_loss = model(x, y, lengths)
print(nll_loss)
"""
tensor([[1.1180, 1.3059, 1.2426, 1.7773],
[1.4231, 1.2783, 1.7321, 0.0000],
[1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
"""
def test_rnn_lm_model_tie_weights():
model = RnnLmModel(
vocab_size=10,
embedding_dim=10,
hidden_dim=10,
num_layers=2,
tie_weights=True,
)
assert model.input_embedding.weight is model.output_linear.weight
def main():
test_rnn_lm_model()
test_rnn_lm_model_tie_weights()
if __name__ == "__main__":
torch.manual_seed(20211122)
main()

617
icefall/rnn_lm/train.py Executable file
View File

@ -0,0 +1,617 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--embedding-dim 800 \
--hidden-dim 200 \
--num-layers 2\
--batch-size 400
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloader
from lhotse.utils import fix_random_seed
from model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
)
parser.add_argument(
"--lm-data",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data.pt",
help="LM training data",
)
parser.add_argument(
"--lm-data-valid",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
help="LM validation data",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
"lr": 1e-3,
"weight_decay": 1e-6,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 5000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model, e.g., RnnLmModel.
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_ppl", tot_ppl, params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
is_distributed = world_size > 1
fix_random_seed(params.seed)
if is_distributed:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if is_distributed:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=is_distributed,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=is_distributed,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if is_distributed:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if is_distributed:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -35,6 +35,8 @@ import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path]
@ -90,7 +92,11 @@ def str2bool(v):
def setup_logger(
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
log_filename: Pathlike,
log_level: str = "info",
rank: int = 0,
world_size: int = 1,
use_console: bool = True,
) -> None:
"""Setup log level.
@ -100,12 +106,16 @@ def setup_logger(
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
rank:
Rank of this node in DDP training.
world_size:
Number of nodes in DDP training.
use_console:
True to also print logs to console.
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
if world_size > 1:
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
log_filename = f"{log_filename}-{date_time}-{rank}"
else:
@ -799,3 +809,34 @@ def optim_step_and_measure_param_change(
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change
def load_averaged_model(
model_dir: str,
model: torch.nn.Module,
epoch: int,
avg: int,
device: torch.device,
):
"""
Load a model which is the average of all checkpoints
:param model_dir: a str of the experiment directory
:param model: a torch.nn.Module instance
:param epoch: the last epoch to load from
:param avg: how many models to average from
:param device: move model to this device
:return: A model averaged
"""
# start cannot be negative
start = max(epoch - avg + 1, 0)
filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
return model