From bc122eec64440b7ddb14138cbb5932042dc7eb7e Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 17 Nov 2021 14:36:40 +0800
Subject: [PATCH] Ready to train a masked LM with vocab size 500.

---
 egs/librispeech/ASR/conformer_lm/dataset.py |  8 ++-
 egs/librispeech/ASR/conformer_lm/train.py   | 63 ++++++++++++++-------
 2 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index 8d24873ed..078268732 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -33,7 +33,13 @@ class LmDataset(torch.utils.data.Dataset):
         # self.sentences[i] returns a 1-D tensor containing word indexes
         # self.words[self.sentences[i]] returns a ragged tensor with axes
         # [word][token].
-        return self.words[self.sentences[i]].values.tolist()
+        word_tokens = self.words[self.sentences[i]]
+        # TODO(fangjun): Need to figure out why `word_tokens`
+        # can be a torch.Tensor
+        if isinstance(word_tokens, torch.Tensor):
+            return word_tokens
+        else:
+            return word_tokens.values.tolist()
 
 
 def load_train_test_lm_dataset(archive_fn: Union[str,Path],
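Note on the dataset.py change above: `self.words` is a ragged [word][token]
array, so indexing it with the word ids of `self.sentences[i]` and flattening
the result yields the BPE token ids of sentence `i`. Here is a minimal
pure-Python sketch of that lookup, using a plain list of lists in place of
the k2 ragged tensor; all names in it are illustrative, not part of the patch:

    from typing import List

    def sentence_to_token_ids(
        sentence: List[int], words: List[List[int]]
    ) -> List[int]:
        """Map a sentence, given as word ids, to a flat list of token ids."""
        token_ids: List[int] = []
        for word_id in sentence:
            # words[word_id] is the variable-length token sequence of this
            # word, i.e. one row of the ragged [word][token] array.
            token_ids.extend(words[word_id])
        return token_ids

    # Example: word 2 is tokenized into BPE tokens 7 and 8.
    words = [[0], [5], [7, 8], [9]]
    assert sentence_to_token_ids([2, 1, 3], words) == [7, 8, 5, 9]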
- "exp_dir": Path("conformer_lm/exp_6"), - "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), - "num_tokens": 5000, + # "exp_dir": Path("conformer_lm/exp_6"), + # "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), + # "num_tokens": 5000, "blank_sym": 0, "bos_sym": 1, "eos_sym": 1, - "start_epoch": 2, - "num_epochs": 20, "num_valid_batches": 200, "symbols_per_batch": 5000, "best_train_loss": float("inf"), @@ -152,7 +175,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, "beam_size": 10, @@ -607,6 +630,8 @@ def run(rank, world_size, args): def main(): parser = get_parser() args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_dataset = Path(args.lm_dataset) world_size = args.world_size assert world_size >= 1