From 2213154c69b3c76e3ecb52508c0a45a1fbe5bebe Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 22 Nov 2021 20:05:18 +0800 Subject: [PATCH] Add scripts for training and perplexity computation. --- egs/ptb/LM/prepare.sh | 20 + egs/ptb/LM/rnn_lm/compute_perplexity.py | 228 ++++++++++ egs/ptb/LM/rnn_lm/dataset.py | 60 ++- egs/ptb/LM/rnn_lm/model.py | 145 +++++++ egs/ptb/LM/rnn_lm/test_model.py | 84 ++++ egs/ptb/LM/rnn_lm/train.py | 554 ++++++++++++++++++++++++ 6 files changed, 1089 insertions(+), 2 deletions(-) create mode 100755 egs/ptb/LM/rnn_lm/compute_perplexity.py create mode 100644 egs/ptb/LM/rnn_lm/model.py create mode 100755 egs/ptb/LM/rnn_lm/test_model.py create mode 100755 egs/ptb/LM/rnn_lm/train.py diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh index 33b1e405a..70586785d 100755 --- a/egs/ptb/LM/prepare.sh +++ b/egs/ptb/LM/prepare.sh @@ -72,6 +72,16 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --bpe-model $out_dir/bpe.model \ --lm-data $dl_dir/ptb.train.txt \ --lm-archive $out_dir/lm_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.test.txt \ + --lm-archive $out_dir/lm_data-test.pt done fi @@ -91,5 +101,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --in-lm-data $out_dir/lm_data.pt \ --out-lm-data $out_dir/sorted_lm_data.pt \ --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt done fi diff --git a/egs/ptb/LM/rnn_lm/compute_perplexity.py b/egs/ptb/LM/rnn_lm/compute_perplexity.py new file mode 100755 index 000000000..ee64ca0d5 --- /dev/null +++ b/egs/ptb/LM/rnn_lm/compute_perplexity.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/compute_perplexity.py \ + --epoch 4 \ + --avg 2 \ + --lm-data ./data/bpe_500/sorted_lm_data-test.pt + +""" + +import argparse +import logging +import math +from pathlib import Path + +import torch +from rnn_lm.dataset import get_dataloader +from rnn_lm.model import RnnLmModel + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=49, + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lm-data", + type=str, + help="Path to the LM test data for computing perplexity", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=100, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help="SOS ID", + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help="EOS ID", + ) + + parser.add_argument( + "--blank-id", + type=int, + default=0, + help="Blank ID", + ) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_data = Path(args.lm_data) + + params = AttributeDict(vars(args)) + print(params) + + setup_logger(f"{params.exp_dir}/log-ppl/") + logging.info("Computing perplexity started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + num_param_requires_grad = sum( + [p.numel() for p in model.parameters() if p.requires_grad] + ) + + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Number of model parameters (requires_grad): " + f"{num_param_requires_grad} " + f"({num_param_requires_grad/num_param_requires_grad*100}%)" + ) + + logging.info(f"Loading LM test data from {params.lm_data}") + test_dl = get_dataloader( + filename=params.lm_data, + is_distributed=False, + params=params, + ) + + tot_loss = 0.0 + num_tokens = 0 + num_sentences = 0 + for batch_idx, batch in enumerate(test_dl): + x, y, sentence_lengths = batch + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum().cpu().item() + + tot_loss += loss + num_tokens += sentence_lengths.sum().cpu().item() + num_sentences += x.size(0) + + ppl = math.exp(tot_loss / num_tokens) + logging.info( + f"total nll: 
{tot_loss}, num tokens: {num_tokens}, " + f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +if __name__ == "__main__": + main() diff --git a/egs/ptb/LM/rnn_lm/dataset.py b/egs/ptb/LM/rnn_lm/dataset.py index a7aaf37ac..eb549f415 100644 --- a/egs/ptb/LM/rnn_lm/dataset.py +++ b/egs/ptb/LM/rnn_lm/dataset.py @@ -19,6 +19,8 @@ from typing import List, Tuple import k2 import torch +from icefall.utils import AttributeDict + class LmDataset(torch.utils.data.Dataset): def __init__( @@ -233,7 +235,7 @@ class LmDatasetCollate: for a sentence starting with `self.sos_id`. It is padded to the max sentence length with `self.blank_id`. - - x, a 2-D tensor of dtype torch.int32; each row contains tokens + - y, a 2-D tensor of dtype torch.int32; each row contains tokens for a sentence ending with `self.eos_id` before padding. Then it is padded to the max sentence length with `self.blank_id`. @@ -257,4 +259,58 @@ class LmDatasetCollate: ) sentence_token_lengths += 1 # plus 1 since we added a SOS - return x, y, sentence_token_lengths + return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths + + +def get_dataloader( + filename: str, + is_distributed: bool, + params: AttributeDict, +) -> torch.utils.data.DataLoader: + """Get dataloader for LM training. + + Args: + filename: + Path to the file containing LM data. The file is assumed to + be generated by `../local/sort_lm_training_data.py`. + is_distributed: + True if using DDP training. False otherwise. + params: + Set `get_params()` from `rnn_lm/train.py` + Returns: + Return a dataloader containing the LM data. + """ + lm_data = torch.load(filename) + + words = lm_data["words"] + sentences = lm_data["sentences"] + sentence_lengths = lm_data["sentence_lengths"] + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=params.max_sent_len, + batch_size=params.batch_size, + ) + if is_distributed: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, shuffle=True, drop_last=False + ) + else: + sampler = None + + collate_fn = LmDatasetCollate( + sos_id=params.sos_id, + eos_id=params.eos_id, + blank_id=params.blank_id, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + collate_fn=collate_fn, + sampler=sampler, + shuffle=sampler is None, + ) + return dataloader diff --git a/egs/ptb/LM/rnn_lm/model.py b/egs/ptb/LM/rnn_lm/model.py new file mode 100644 index 000000000..86a670a61 --- /dev/null +++ b/egs/ptb/LM/rnn_lm/model.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +import torch.nn.functional as F + + +def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. 
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False,  True,  True,  True,  True],
+            [False, False, False,  True,  True],
+            [False, False,  True,  True,  True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+
+    max_len = lengths.max()
+    n = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+class RnnLmModel(torch.nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_dim: int,
+        hidden_dim: int,
+        num_layers: int,
+        tie_weights: bool = False,
+    ):
+        """
+        Args:
+          vocab_size:
+            Vocabulary size of the BPE model.
+          embedding_dim:
+            Input embedding dimension.
+          hidden_dim:
+            Hidden dimension of RNN layers.
+          num_layers:
+            Number of RNN layers.
+          tie_weights:
+            True to share the weights between the input embedding layer and the
+            last output linear layer. See https://arxiv.org/abs/1608.05859
+            and https://arxiv.org/abs/1611.01462
+        """
+        super().__init__()
+
+        self.input_embedding = torch.nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim,
+        )
+
+        self.rnn = torch.nn.LSTM(
+            input_size=embedding_dim,
+            hidden_size=hidden_dim,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+
+        self.output_linear = torch.nn.Linear(
+            in_features=hidden_dim, out_features=vocab_size
+        )
+
+        self.vocab_size = vocab_size
+        if tie_weights:
+            logging.info("Tying weights")
+            assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
+            self.output_linear.weight = self.input_embedding.weight
+        else:
+            logging.info("Not tying weights")
+
+    def forward(
+        self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          x:
+            A 2-D tensor with shape (N, L). Each row
+            contains token IDs for a sentence and starts with the SOS token.
+          y:
+            A shifted version of `x` with the EOS token appended.
+          lengths:
+            A 1-D tensor of shape (N,). It contains the sentence lengths
+            before padding.
+        Returns:
+          Return a 2-D tensor of shape (N, L) containing negative log-likelihood
+          loss values. Note: Loss values for padding positions are set to 0.
+ """ + assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) + assert lengths.ndim == 1, lengths.ndim + assert x.shape == y.shape, (x.shape, y.shape) + + batch_size = x.size(0) + assert lengths.size(0) == batch_size, (lengths.size(0), batch_size) + + # embedding is of shape (N, L, embedding_dim) + embedding = self.input_embedding(x) + + # Note: We use batch_first==True + rnn_out, _ = self.rnn(embedding) + logits = self.output_linear(rnn_out) + + # Note: No need to use `log_softmax()` here + # since F.cross_entropy() expects unnormalized probabilities + + # nll_loss is of shape (N*L,) + # nll -> negative log-likelihood + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + # Set loss values for padding positions to 0 + mask = make_pad_mask(lengths).reshape(-1) + nll_loss.masked_fill_(mask, 0) + + nll_loss = nll_loss.reshape(batch_size, -1) + + return nll_loss diff --git a/egs/ptb/LM/rnn_lm/test_model.py b/egs/ptb/LM/rnn_lm/test_model.py new file mode 100755 index 000000000..503a74528 --- /dev/null +++ b/egs/ptb/LM/rnn_lm/test_model.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from rnn_lm.model import RnnLmModel, make_pad_mask + + +def test_makd_pad_mask(): + lengths = torch.tensor([1, 3, 2]) + mask = make_pad_mask(lengths) + expected = torch.tensor( + [ + [False, True, True], + [False, False, False], + [False, False, True], + ] + ) + assert torch.all(torch.eq(mask, expected)) + assert (~expected).sum() == lengths.sum() + + +def test_rnn_lm_model(): + vocab_size = 4 + model = RnnLmModel( + vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2 + ) + x = torch.tensor( + [ + [1, 3, 2, 2], + [1, 2, 2, 0], + [1, 2, 0, 0], + ] + ) + y = torch.tensor( + [ + [3, 2, 2, 1], + [2, 2, 1, 0], + [2, 1, 0, 0], + ] + ) + lengths = torch.tensor([4, 3, 2]) + nll_loss = model(x, y, lengths) + print(nll_loss) + """ + tensor([[1.1180, 1.3059, 1.2426, 1.7773], + [1.4231, 1.2783, 1.7321, 0.0000], + [1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=) + """ + + +def test_rnn_lm_model_tie_weights(): + model = RnnLmModel( + vocab_size=10, + embedding_dim=10, + hidden_dim=10, + num_layers=2, + tie_weights=True, + ) + assert model.input_embedding.weight is model.output_linear.weight + + +def main(): + test_makd_pad_mask() + test_rnn_lm_model() + test_rnn_lm_model_tie_weights() + + +if __name__ == "__main__": + torch.manual_seed(20211122) + main() diff --git a/egs/ptb/LM/rnn_lm/train.py b/egs/ptb/LM/rnn_lm/train.py new file mode 100755 index 000000000..c1cbdf377 --- /dev/null +++ b/egs/ptb/LM/rnn_lm/train.py @@ -0,0 +1,554 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/train.py \ + --world-size 2 \ + --start-epoch 4 +""" + +import argparse +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from lhotse.utils import fix_random_seed +from rnn_lm.dataset import get_dataloader +from rnn_lm.model import RnnLmModel +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_env_info, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + exp_dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, logs, etc, are saved + """, + ) + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters.""" + + params = AttributeDict( + { + # LM training/validation data + "lm_data": "data/bpe_500/sorted_lm_data.pt", + "lm_data_valid": "data/bpe_500/sorted_lm_data-valid.pt", + "batch_size": 50, + "max_sent_len": 200, + "sos_id": 1, + "eos_id": 1, + "blank_id": 0, + # model related + # + # vocab size of the BPE model + "vocab_size": 500, + "embedding_dim": 2048, + "hidden_dim": 2048, + "num_layers": 4, + # + "lr": 1e-3, + "weight_decay": 1e-6, + # + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 300, + "env_info": get_env_info(), + } + ) + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + logging.info(f"Loading checkpoint: {filename}") + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + sentence_lengths: torch.Tensor, + is_training: bool, +) -> Tuple[torch.Tensor, MetricsTracker]: + """Compute the negative log-likelihood loss given a model and its input. + Args: + model: + The NN model, e.g., RnnLmModel. + x: + A 2-D tensor. Each row contains BPE token IDs for a sentence. 
Also, + each row starts with SOS ID. + y: + A 2-D tensor. Each row is a shifted version of the corresponding row + in `x` but ends with an EOS ID (before padding). + sentence_lengths: + A 1-D tensor containing number of tokens of each sentence + before padding. + is_training: + True for training. False for validation. + """ + with torch.set_grad_enabled(is_training): + device = model.device + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum() + + num_tokens = sentence_lengths.sum().item() + + loss_info = MetricsTracker() + # Note: Due to how MetricsTracker() is designed, + # we use "frames" instead of "num_tokens" as a key here + loss_info["frames"] = num_tokens + loss_info["loss"] = loss.detach().item() + return loss, loss_info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + x, y, sentence_lengths = batch + + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all sentences is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + x, y, sentence_lengths = batch + batch_size = x.size(0) + + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + # Note: "frames" here means "num_tokens" + this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) + tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " + f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/current_ppl", this_batch_ppl, params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/tot_ppl", tot_ppl, params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + + valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) + logging.info( + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/valid_ppl", valid_ppl, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + if checkpoints: + logging.info("Load optimizer state_dict from checkpoint") + optimizer.load_state_dict(checkpoints["optimizer"]) + + logging.info(f"Loading LM training data from {params.lm_data}") + train_dl = get_dataloader( + filename=params.lm_data, + is_distributed=world_size > 1, + params=params, + ) + + logging.info(f"Loading LM validation data from {params.lm_data_valid}") + valid_dl = get_dataloader( + filename=params.lm_data_valid, + is_distributed=world_size > 1, + params=params, + ) + + # Note: No learning rate scheduler is used here + for epoch in range(params.start_epoch, params.num_epochs): + if world_size > 1: + train_dl.sampler.set_epoch(epoch) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main()