diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 5eb07fae5..3c5027c77 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1299,17 +1299,18 @@ You can find the tensorboard log at: +and the RNN-LM pre-trained model: + + The tensorboard log for training is available at diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 177e33a6e..0e8247b8d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.decode import ( get_lattice, nbest_decoding, @@ -38,15 +38,19 @@ from icefall.decode import ( one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, + rescore_with_rnn_lm, rescore_with_whole_lattice, ) from icefall.env import get_env_info from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, get_texts, + load_averaged_model, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -93,7 +97,9 @@ def get_parser(): is the decoding result. - (5) attention-decoder. Extract n paths from the LM rescored lattice, the path with the highest score is the decoding result. - - (6) nbest-oracle. Its WER is the lower bound of any n-best + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -105,7 +111,7 @@ def get_parser(): default=100, help="""Number of paths for n-best based decoding method. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle """, ) @@ -116,7 +122,7 @@ def get_parser(): help="""The scale to be applied to `lattice.scores`. It's needed if you use any kinds of n-best based rescoring. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle A smaller value results in more unique paths. """, ) @@ -139,11 +145,67 @@ def get_parser(): "--lm-dir", type=str, default="data/lm", - help="""The LM dir. + help="""The n-gram LM dir. It should contain either G_4_gram.pt or G_4_gram.fst.txt """, ) + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + return parser @@ -173,6 +235,7 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -205,6 +268,8 @@ def decode_one_batch( model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -330,6 +395,7 @@ def decode_one_batch( "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -357,8 +423,6 @@ def decode_one_batch( G_with_epsilon_loops=G, lm_scale_list=None, ) - # TODO: pass `lattice` instead of `rescored_lattice` to - # `rescore_with_attention_decoder` best_path_dict = rescore_with_attention_decoder( lattice=rescored_lattice, @@ -370,6 +434,26 @@ def decode_one_batch( eos_id=eos_id, nbest_scale=params.nbest_scale, ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -388,6 +472,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -405,6 +490,8 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -442,6 +529,7 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model, @@ -490,7 +578,7 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - if params.method == "attention-decoder": + if params.method in ("attention-decoder", "rnn-lm"): # Set it to False since there are too many logs. enable_log = False else: @@ -566,6 +654,10 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + if params.method == "ctc-decoding": HLG = None H = k2.ctc_topo( @@ -590,6 +682,7 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -621,7 +714,11 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) @@ -648,20 +745,40 @@ def main(): if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() @@ -678,6 +795,7 @@ def main(): dl=test_dl, params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model, diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 94d23afed..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -23,6 +23,7 @@ This file downloads the following LibriSpeech LM files: - 4-gram.arpa.gz - librispeech-vocab.txt - librispeech-lexicon.txt + - librispeech-lm-norm.txt.gz from http://www.openslr.org/resources/11 and save them in the user provided directory. @@ -61,6 +62,7 @@ def main(out_dir: str): "4-gram.arpa.gz", "librispeech-vocab.txt", "librispeech-lexicon.txt", + "librispeech-lm-norm.txt.gz", ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py new file mode 100755 index 000000000..5070341f1 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `bpe.model` and a text file such as +./download/lm/librispeech-lm-norm.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_bpe_500. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the following format: + + 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 + containing the BPE representations of each word, indexed by + integer word ID. (These integer word IDS are present in + 'lm_data'). The sentencepiece object can be used to turn the + words and BPE units into string form. + 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype + torch.int32 containing all the sentences, as word-ids (we don't + output the string form of this directly but it can be worked out + together with 'words' and the bpe.model). + 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing + number of BPE tokens of each sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + type=str, + help="Input BPE model, e.g. data/bpe_500/bpe.model", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/pb.train.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + if "librispeech-lm-norm" in args.lm_data: + num_lines_in_total = 40418261.0 + step = 5000000 + elif "valid" in args.lm_data: + num_lines_in_total = 5567.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 5559.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed/num_lines_in_total*100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_bpe = sp.encode(w) + word2index[w] = len(word2bpe) + word2bpe.append(w_bpe) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2bpe) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/sort_lm_training_data.py b/egs/librispeech/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..bb26b5f5c --- /dev/null +++ b/egs/librispeech/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../ptb/LM/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index bc5812810..42aba9572 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -38,7 +38,6 @@ def get_args(): "--lang-dir", type=str, help="""Input and output directory. - It should contain the training corpus: transcript_words.txt. The generated bpe.model is saved to this directory. """, ) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 17a638502..94e003036 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -24,6 +24,7 @@ stop_stage=100 # - 4-gram.arpa # - librispeech-vocab.txt # - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -40,9 +41,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - # 5000 - # 2000 - # 1000 + 5000 + 2000 + 1000 500 ) @@ -278,3 +279,99 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then ./local/compile_lg.py --lang-dir $lang_dir done fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/test.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/ptb/LM/README.md b/egs/ptb/LM/README.md new file mode 100644 index 000000000..7629a950d --- /dev/null +++ b/egs/ptb/LM/README.md @@ -0,0 +1,18 @@ +## Description + +(Note: the experiments here are only about language modeling) + +ptb is short for Penn Treebank. + + +About the Penn Treebank corpus: + - This corpus is free for research purposes + - ptb.train.txt: train set + - ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training) + - ptb.test.txt: test set for reporting perplexity + +You can download the dataset from one of the following URLs: + +- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage +- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz +- https://deepai.org/dataset/penn-treebank diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py new file mode 120000 index 000000000..eebce957e --- /dev/null +++ b/egs/ptb/LM/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py new file mode 100755 index 000000000..af54dbd07 --- /dev/null +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py new file mode 100755 index 000000000..877720e7b --- /dev/null +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import sentencepiece as spm +import torch + + +def main(): + lm_training_data = Path("./data/bpe_500/lm_data.pt") + bpe_model = Path("./data/bpe_500/bpe.model") + if not lm_training_data.exists(): + logging.warning(f"{lm_training_data} does not exist - skipping") + return + + if not bpe_model.exists(): + logging.warning(f"{bpe_model} does not exist - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(str(bpe_model)) + + data = torch.load(lm_training_data) + words2bpe = data["words"] + sentences = data["sentences"] + + ss = [] + unk = sp.decode(sp.unk_id()).strip() + for i in range(10): + s = sp.decode(words2bpe[sentences[i]].values.tolist()) + s = s.replace(unk, "") + ss.append(s) + + for s in ss: + print(s) + # You can compare the output with the first 10 lines of ptb.train.txt + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py new file mode 120000 index 000000000..6f018a0e2 --- /dev/null +++ b/egs/ptb/LM/local/train_bpe_model.py @@ -0,0 +1 @@ +../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh new file mode 100755 index 000000000..70586785d --- /dev/null +++ b/egs/ptb/LM/prepare.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download +# The following files will be downloaded to $dl_dir +# - ptb.train.txt +# - ptb.valid.txt +# - ptb.test.txt + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/bpe_xxx, data/bpe_yyy +# if the array contains xxx, yyy +vocab_sizes=( + 500 + 1000 + 2000 + 5000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +mkdir -p $dl_dir + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download data" + if [ ! -f $dl_dir/.complete ]; then + url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/ + wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt + wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt + wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt + touch $dl_dir/.complete + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train BPE model" + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/train_bpe_model.py \ + --out-dir $out_dir \ + --vocab-size $vocab_size \ + --transcript $dl_dir/ptb.train.txt + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Generate LM training data" + # Note: ptb.train.txt has already been normalized + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.train.txt \ + --lm-archive $out_dir/lm_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Sort LM training data" + # Sort LM training data generated in stage 1 + # by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/ptb/LM/shared b/egs/ptb/LM/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/ptb/LM/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/icefall/decode.py b/icefall/decode.py index 3ba899b4e..680e29619 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -20,7 +20,34 @@ from typing import Dict, List, Optional, Union import k2 import torch -from icefall.utils import get_texts +from icefall.utils import add_eos, add_sos, get_texts + +DEFAULT_LM_SCALE = [ + 0.01, + 0.05, + 0.08, + 0.1, + 0.3, + 0.5, + 0.6, + 0.7, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.5, + 1.7, + 1.9, + 2.0, + 2.1, + 2.2, + 2.3, + 2.5, + 3.0, + 4.0, + 5.0, +] def _intersect_device( @@ -952,3 +979,161 @@ def rescore_with_attention_decoder( key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" ans[key] = best_path return ans + + +def rescore_with_rnn_lm( + lattice: k2.Fsa, + num_paths: int, + rnn_lm_model: torch.nn.Module, + model: torch.nn.Module, + memory: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], + sos_id: int, + eos_id: int, + blank_id: int, + nbest_scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, + rnn_lm_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(T, N, C)`. + memory_key_padding_mask: + The padding mask for memory with shape `(N, T)`. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + rnn_lm_scale: + Optional. It specifies the scale for RNN LM scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + + nbest = nbest.intersect(lattice) + # Now nbest.fsa has its scores set. + # Also, nbest.fsa inherits the attributes from `lattice`. + assert hasattr(nbest.fsa, "lm_scores") + + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() + + # The `tokens` attribute is set inside `compile_hlg.py` + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) + + path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # the shape of memory is (T, N, C), so we use axis=1 here + expanded_memory = memory.index_select(1, path_to_utt_map) + + if memory_key_padding_mask is not None: + # The shape of memory_key_padding_mask is (N, T), so we + # use axis=0 here. + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_utt_map + ) + else: + expanded_memory_key_padding_mask = None + + # remove axis corresponding to states. + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + token_ids = tokens.tolist() + + if len(token_ids) == 0: + print("Warning: rescore_with_attention_decoder(): empty token-ids") + return None + + nll = model.decoder_nll( + memory=expanded_memory, + memory_key_padding_mask=expanded_memory_key_padding_mask, + token_ids=token_ids, + sos_id=sos_id, + eos_id=eos_id, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + # Now for RNN LM + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_ids) + + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ngram_lm_scale_list = DEFAULT_LM_SCALE + attention_scale_list = DEFAULT_LM_SCALE + rnn_lm_scale_list = DEFAULT_LM_SCALE + + if ngram_lm_scale: + ngram_lm_scale_list = [ngram_lm_scale] + + if attention_scale: + attention_scale_list = [attention_scale] + + if rnn_lm_scale: + rnn_lm_scale_list = [rnn_lm_scale] + + ans = dict() + for n_scale in ngram_lm_scale_list: + for a_scale in attention_scale_list: + for r_scale in rnn_lm_scale_list: + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores.values + + a_scale * attention_scores + + r_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa + ans[key] = best_path + return ans diff --git a/icefall/dist.py b/icefall/dist.py index 203c7c563..6334f9c13 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,14 +21,46 @@ import torch from torch import distributed as dist -def setup_dist(rank, world_size, master_port=None): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = ( - "12354" if master_port is None else str(master_port) - ) - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) +def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): + """ + rank and world_size are used only if use_ddp_launch is False. + """ + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = ( + "12354" if master_port is None else str(master_port) + ) + + if use_ddp_launch is False: + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + else: + dist.init_process_group("nccl") def cleanup_dist(): dist.destroy_process_group() + + +def get_world_size(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def get_rank(): + if "RANK" in os.environ: + return int(os.environ["RANK"]) + elif dist.is_available() and dist.is_initialized(): + return dist.rank() + else: + return 1 + + +def get_local_rank(): + return int(os.environ.get("LOCAL_RANK", 0)) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py new file mode 100755 index 000000000..550801a8f --- /dev/null +++ b/icefall/rnn_lm/compute_perplexity.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/compute_perplexity.py \ + --epoch 4 \ + --avg 2 \ + --lm-data ./data/bpe_500/sorted_lm_data-test.pt + +""" + +import argparse +import logging +import math +from pathlib import Path + +import torch +from dataset import get_dataloader +from model import RnnLmModel + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=49, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lm-data", + type=str, + help="Path to the LM test data for computing perplexity", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=100, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help="SOS ID", + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help="EOS ID", + ) + + parser.add_argument( + "--blank-id", + type=int, + default=0, + help="Blank ID", + ) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_data = Path(args.lm_data) + + params = AttributeDict(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ppl/") + logging.info("Computing perplexity started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + num_param_requires_grad = sum( + [p.numel() for p in model.parameters() if p.requires_grad] + ) + + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Number of model parameters (requires_grad): " + f"{num_param_requires_grad} " + f"({num_param_requires_grad/num_param_requires_grad*100}%)" + ) + + logging.info(f"Loading LM test data from {params.lm_data}") + test_dl = get_dataloader( + filename=params.lm_data, + is_distributed=False, + params=params, + ) + + tot_loss = 0.0 + num_tokens = 0 + num_sentences = 0 + for batch_idx, batch in enumerate(test_dl): + x, y, sentence_lengths = batch + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum().cpu().item() + + tot_loss += loss + num_tokens += sentence_lengths.sum().cpu().item() + num_sentences += x.size(0) + + ppl = math.exp(tot_loss / num_tokens) + logging.info( + f"total nll: {tot_loss}, num tokens: {num_tokens}, " + f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py new file mode 100644 index 000000000..598e329c4 --- /dev/null +++ b/icefall/rnn_lm/dataset.py @@ -0,0 +1,218 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import k2 +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from icefall.utils import AttributeDict, add_eos, add_sos + + +class LmDataset(torch.utils.data.Dataset): + def __init__( + self, + sentences: k2.RaggedTensor, + words: k2.RaggedTensor, + sentence_lengths: torch.Tensor, + max_sent_len: int, + batch_size: int, + ): + """ + Args: + sentences: + A ragged tensor of dtype torch.int32 with 2 axes [sentence][word]. + words: + A ragged tensor of dtype torch.int32 with 2 axes [word][token]. + sentence_lengths: + A 1-D tensor of dtype torch.int32 containing number of tokens + of each sentence. + max_sent_len: + Maximum sentence length. It is used to change the batch size + dynamically. In general, we try to keep the product of + "max_sent_len in a batch" and "num_of_sent in a batch" being + a constant. + batch_size: + The expected batch size. It is changed dynamically according + to the "max_sent_len". + + See `../local/prepare_lm_training_data.py` for how `sentences` and + `words` are generated. We assume that `sentences` are sorted by length. + See `../local/sort_lm_training_data.py`. + """ + super().__init__() + self.sentences = sentences + self.words = words + + sentence_lengths = sentence_lengths.tolist() + + assert batch_size > 0, batch_size + assert max_sent_len > 1, max_sent_len + batch_indexes = [] + num_sentences = sentences.dim0 + cur = 0 + while cur < num_sentences: + sz = sentence_lengths[cur] // max_sent_len + 1 + # Assume the current sentence has 3 * max_sent_len tokens, + # in the worst case, the subsequent sentences also have + # this number of tokens, we should reduce the batch size + # so that this batch will not contain too many tokens + actual_batch_size = batch_size // sz + 1 + actual_batch_size = min(actual_batch_size, batch_size) + end = cur + actual_batch_size + end = min(end, num_sentences) + this_batch_indexes = torch.arange(cur, end).tolist() + batch_indexes.append(this_batch_indexes) + cur = end + assert batch_indexes[-1][-1] == num_sentences - 1 + + self.batch_indexes = k2.RaggedTensor(batch_indexes) + + def __len__(self) -> int: + """Return number of batches in this dataset""" + return self.batch_indexes.dim0 + + def __getitem__(self, i: int) -> k2.RaggedTensor: + """Get the i'th batch in this dataset + Return a ragged tensor with 2 axes [sentence][token]. + """ + assert 0 <= i < len(self), i + + # indexes is a 1-D tensor containing sentence indexes + indexes = self.batch_indexes[i] + + # sentence_words is a ragged tensor with 2 axes + # [sentence][word] + sentence_words = self.sentences[indexes] + + # in case indexes contains only 1 entry, the returned + # sentence_words is a 1-D tensor, we have to convert + # it to a ragged tensor + if isinstance(sentence_words, torch.Tensor): + sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0)) + + # sentence_word_tokens is a ragged tensor with 3 axes + # [sentence][word][token] + sentence_word_tokens = self.words.index(sentence_words) + assert sentence_word_tokens.num_axes == 3 + + sentence_tokens = sentence_word_tokens.remove_axis(1) + return sentence_tokens + + +class LmDatasetCollate: + def __init__(self, sos_id: int, eos_id: int, blank_id: int): + """ + Args: + sos_id: + Token ID of the SOS symbol. + eos_id: + Token ID of the EOS symbol. + blank_id: + Token ID of the blank symbol. + """ + self.sos_id = sos_id + self.eos_id = eos_id + self.blank_id = blank_id + + def __call__( + self, batch: List[k2.RaggedTensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return a tuple containing 3 tensors: + + - x, a 2-D tensor of dtype torch.int32; each row contains tokens + for a sentence starting with `self.sos_id`. It is padded to + the max sentence length with `self.blank_id`. + + - y, a 2-D tensor of dtype torch.int32; each row contains tokens + for a sentence ending with `self.eos_id` before padding. + Then it is padded to the max sentence length with + `self.blank_id`. + + - lengths, a 2-D tensor of dtype torch.int32, containing the number of + tokens of each sentence before padding. + """ + # The batching stuff has already been done in LmDataset + assert len(batch) == 1 + sentence_tokens = batch[0] + row_splits = sentence_tokens.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id) + sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id) + + x = sentence_tokens_with_sos.pad( + mode="constant", padding_value=self.blank_id + ) + y = sentence_tokens_with_eos.pad( + mode="constant", padding_value=self.blank_id + ) + sentence_token_lengths += 1 # plus 1 since we added a SOS + + return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths + + +def get_dataloader( + filename: str, + is_distributed: bool, + params: AttributeDict, +) -> torch.utils.data.DataLoader: + """Get dataloader for LM training. + + Args: + filename: + Path to the file containing LM data. The file is assumed to + be generated by `../local/sort_lm_training_data.py`. + is_distributed: + True if using DDP training. False otherwise. + params: + Set `get_params()` from `rnn_lm/train.py` + Returns: + Return a dataloader containing the LM data. + """ + lm_data = torch.load(filename) + + words = lm_data["words"] + sentences = lm_data["sentences"] + sentence_lengths = lm_data["sentence_lengths"] + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=params.max_sent_len, + batch_size=params.batch_size, + ) + if is_distributed: + sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) + else: + sampler = None + + collate_fn = LmDatasetCollate( + sos_id=params.sos_id, + eos_id=params.eos_id, + blank_id=params.blank_id, + ) + + dataloader = DataLoader( + dataset, + batch_size=1, + collate_fn=collate_fn, + sampler=sampler, + shuffle=sampler is None, + ) + return dataloader diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py new file mode 100644 index 000000000..094035fce --- /dev/null +++ b/icefall/rnn_lm/export.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from model import RnnLmModel + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, load_averaged_model, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict({}) + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py new file mode 100644 index 000000000..88b2cc41f --- /dev/null +++ b/icefall/rnn_lm/model.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + + +class RnnLmModel(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + hidden_dim: int, + num_layers: int, + tie_weights: bool = False, + ): + """ + Args: + vocab_size: + Vocabulary size of BPE model. + embedding_dim: + Input embedding dimension. + hidden_dim: + Hidden dimension of RNN layers. + num_layers: + Number of RNN layers. + tie_weights: + True to share the weights between the input embedding layer and the + last output linear layer. See https://arxiv.org/abs/1608.05859 + and https://arxiv.org/abs/1611.01462 + """ + super().__init__() + + self.input_embedding = torch.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + ) + + self.rnn = torch.nn.LSTM( + input_size=embedding_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + ) + + self.output_linear = torch.nn.Linear( + in_features=hidden_dim, out_features=vocab_size + ) + + self.vocab_size = vocab_size + if tie_weights: + logging.info("Tying weights") + assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim) + self.output_linear.weight = self.input_embedding.weight + else: + logging.info("Not tying weights") + + def forward( + self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor with shape (N, L). Each row + contains token IDs for a sentence and starts with the SOS token. + y: + A shifted version of `x` and with EOS appended. + lengths: + A 1-D tensor of shape (N,). It contains the sentence lengths + before padding. + Returns: + Return a 2-D tensor of shape (N, L) containing negative log-likelihood + loss values. Note: Loss values for padding positions are set to 0. + """ + assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) + assert lengths.ndim == 1, lengths.ndim + assert x.shape == y.shape, (x.shape, y.shape) + + batch_size = x.size(0) + assert lengths.size(0) == batch_size, (lengths.size(0), batch_size) + + # embedding is of shape (N, L, embedding_dim) + embedding = self.input_embedding(x) + + # Note: We use batch_first==True + rnn_out, _ = self.rnn(embedding) + logits = self.output_linear(rnn_out) + + # Note: No need to use `log_softmax()` here + # since F.cross_entropy() expects unnormalized probabilities + + # nll_loss is of shape (N*L,) + # nll -> negative log-likelihood + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + # Set loss values for padding positions to 0 + mask = make_pad_mask(lengths).reshape(-1) + nll_loss.masked_fill_(mask, 0) + + nll_loss = nll_loss.reshape(batch_size, -1) + + return nll_loss diff --git a/icefall/rnn_lm/test_dataset.py b/icefall/rnn_lm/test_dataset.py new file mode 100755 index 000000000..bf961f54b --- /dev/null +++ b/icefall/rnn_lm/test_dataset.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import k2 +import torch +from rnn_lm.dataset import LmDataset, LmDatasetCollate + + +def main(): + sentences = k2.RaggedTensor( + [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] + ) + words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) + + num_sentences = sentences.dim0 + + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) + + indices = torch.argsort(sentence_lengths, descending=True) + sentences = sentences[indices.to(torch.int32)] + sentence_lengths = sentence_lengths[indices] + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=3, + batch_size=4, + ) + + collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, collate_fn=collate_fn + ) + + for i in dataloader: + print(i) + # I've checked the output manually; the output is as expected. + + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/test_dataset_ddp.py b/icefall/rnn_lm/test_dataset_ddp.py new file mode 100755 index 000000000..48fbb19cb --- /dev/null +++ b/icefall/rnn_lm/test_dataset_ddp.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import k2 +import torch +import torch.multiprocessing as mp +from rnn_lm.dataset import LmDataset, LmDatasetCollate +from torch import distributed as dist + + +def generate_data(): + sentences = k2.RaggedTensor( + [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] + ) + words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) + + num_sentences = sentences.dim0 + + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) + + indices = torch.argsort(sentence_lengths, descending=True) + sentences = sentences[indices.to(torch.int32)] + sentence_lengths = sentence_lengths[indices] + + return sentences, words, sentence_lengths + + +def run(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12352" + + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + sentences, words, sentence_lengths = generate_data() + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=3, + batch_size=4, + ) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, shuffle=True, drop_last=False + ) + + collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + collate_fn=collate_fn, + sampler=sampler, + shuffle=False, + ) + + for i in dataloader: + print(f"rank: {rank}", i) + + dist.destroy_process_group() + + +def main(): + world_size = 2 + mp.spawn(run, args=(world_size,), nprocs=world_size, join=True) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/test_model.py b/icefall/rnn_lm/test_model.py new file mode 100755 index 000000000..5a216a3fb --- /dev/null +++ b/icefall/rnn_lm/test_model.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from rnn_lm.model import RnnLmModel + + +def test_rnn_lm_model(): + vocab_size = 4 + model = RnnLmModel( + vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2 + ) + x = torch.tensor( + [ + [1, 3, 2, 2], + [1, 2, 2, 0], + [1, 2, 0, 0], + ] + ) + y = torch.tensor( + [ + [3, 2, 2, 1], + [2, 2, 1, 0], + [2, 1, 0, 0], + ] + ) + lengths = torch.tensor([4, 3, 2]) + nll_loss = model(x, y, lengths) + print(nll_loss) + """ + tensor([[1.1180, 1.3059, 1.2426, 1.7773], + [1.4231, 1.2783, 1.7321, 0.0000], + [1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=) + """ + + +def test_rnn_lm_model_tie_weights(): + model = RnnLmModel( + vocab_size=10, + embedding_dim=10, + hidden_dim=10, + num_layers=2, + tie_weights=True, + ) + assert model.input_embedding.weight is model.output_linear.weight + + +def main(): + test_rnn_lm_model() + test_rnn_lm_model_tie_weights() + + +if __name__ == "__main__": + torch.manual_seed(20211122) + main() diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py new file mode 100755 index 000000000..bb5f03fb9 --- /dev/null +++ b/icefall/rnn_lm/train.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 1 \ + --use-fp16 0 \ + --embedding-dim 800 \ + --hidden-dim 200 \ + --num-layers 2\ + --batch-size 400 + +""" + +import argparse +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from dataset import get_dataloader +from lhotse.utils import fix_random_seed +from model import RnnLmModel +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + exp_dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, logs, etc, are saved + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + ) + + parser.add_argument( + "--lm-data", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data.pt", + help="LM training data", + ) + + parser.add_argument( + "--lm-data-valid", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", + help="LM validation data", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters.""" + + params = AttributeDict( + { + "max_sent_len": 200, + "sos_id": 1, + "eos_id": 1, + "blank_id": 0, + "lr": 1e-3, + "weight_decay": 1e-6, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 200, + "reset_interval": 2000, + "valid_interval": 5000, + "env_info": get_env_info(), + } + ) + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + logging.info(f"Loading checkpoint: {filename}") + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + sentence_lengths: torch.Tensor, + is_training: bool, +) -> Tuple[torch.Tensor, MetricsTracker]: + """Compute the negative log-likelihood loss given a model and its input. + Args: + model: + The NN model, e.g., RnnLmModel. + x: + A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, + each row starts with SOS ID. + y: + A 2-D tensor. Each row is a shifted version of the corresponding row + in `x` but ends with an EOS ID (before padding). + sentence_lengths: + A 1-D tensor containing number of tokens of each sentence + before padding. + is_training: + True for training. False for validation. + """ + with torch.set_grad_enabled(is_training): + device = model.device + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum() + + num_tokens = sentence_lengths.sum().item() + + loss_info = MetricsTracker() + # Note: Due to how MetricsTracker() is designed, + # we use "frames" instead of "num_tokens" as a key here + loss_info["frames"] = num_tokens + loss_info["loss"] = loss.detach().item() + return loss, loss_info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + x, y, sentence_lengths = batch + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all sentences is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + x, y, sentence_lengths = batch + batch_size = x.size(0) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + # Note: "frames" here means "num_tokens" + this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) + tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " + f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/current_ppl", this_batch_ppl, params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/tot_ppl", tot_ppl, params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + + valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) + logging.info( + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/valid_ppl", valid_ppl, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + is_distributed = world_size > 1 + + fix_random_seed(params.seed) + if is_distributed: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if is_distributed: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + if checkpoints: + logging.info("Load optimizer state_dict from checkpoint") + optimizer.load_state_dict(checkpoints["optimizer"]) + + logging.info(f"Loading LM training data from {params.lm_data}") + train_dl = get_dataloader( + filename=params.lm_data, + is_distributed=is_distributed, + params=params, + ) + + logging.info(f"Loading LM validation data from {params.lm_data_valid}") + valid_dl = get_dataloader( + filename=params.lm_data_valid, + is_distributed=is_distributed, + params=params, + ) + + # Note: No learning rate scheduler is used here + for epoch in range(params.start_epoch, params.num_epochs): + if is_distributed: + train_dl.sampler.set_epoch(epoch) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if is_distributed: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/utils.py b/icefall/utils.py index b38574f0c..10a2e6301 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -35,6 +35,8 @@ import torch.distributed as dist import torch.nn as nn from torch.utils.tensorboard import SummaryWriter +from icefall.checkpoint import average_checkpoints + Pathlike = Union[str, Path] @@ -90,7 +92,11 @@ def str2bool(v): def setup_logger( - log_filename: Pathlike, log_level: str = "info", use_console: bool = True + log_filename: Pathlike, + log_level: str = "info", + rank: int = 0, + world_size: int = 1, + use_console: bool = True, ) -> None: """Setup log level. @@ -100,12 +106,16 @@ def setup_logger( log_level: The log level to use, e.g., "debug", "info", "warning", "error", "critical" + rank: + Rank of this node in DDP training. + world_size: + Number of nodes in DDP training. + use_console: + True to also print logs to console. """ now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - if dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() + if world_size > 1: formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa log_filename = f"{log_filename}-{date_time}-{rank}" else: @@ -799,3 +809,34 @@ def optim_step_and_measure_param_change( delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) relative_change[n] = delta.item() return relative_change + + +def load_averaged_model( + model_dir: str, + model: torch.nn.Module, + epoch: int, + avg: int, + device: torch.device, +): + """ + Load a model which is the average of all checkpoints + + :param model_dir: a str of the experiment directory + :param model: a torch.nn.Module instance + + :param epoch: the last epoch to load from + :param avg: how many models to average from + :param device: move model to this device + + :return: A model averaged + """ + + # start cannot be negative + start = max(epoch - avg + 1, 0) + filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)] + + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + return model