Ready to train a masked LM with vocab size 500.

Fangjun Kuang 2021-11-17 14:36:40 +08:00
parent 7c3ab28a68
commit bc122eec64
2 changed files with 51 additions and 20 deletions


@@ -33,7 +33,13 @@ class LmDataset(torch.utils.data.Dataset):
# self.sentences[i] returns a 1-D tensor containing word indexes
# self.words[self.sentences[i]] returns a ragged tensor with axes
# [word][token].
return self.words[self.sentences[i]].values.tolist()
word_tokens = self.words[self.sentences[i]]
# TODO(fangjun): Need to figure out why `word_tokens`
# can be a torch.Tensor
if isinstance(word_tokens, torch.Tensor):
    return word_tokens
else:
    return word_tokens.values.tolist()
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
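For readers following the TODO above: the change guards against `self.words[self.sentences[i]]` coming back as a plain torch.Tensor instead of a ragged tensor. Below is a minimal stand-in sketch of the same guard, using a hypothetical `Ragged` class in place of the real k2 ragged tensor; the names are illustrative and not from this repo.

import torch
from typing import List, Union

class Ragged:
    """Toy stand-in for a ragged tensor that exposes a flat `.values` tensor."""
    def __init__(self, values: torch.Tensor):
        self.values = values

def to_token_ids(word_tokens: Union[torch.Tensor, Ragged]) -> List[int]:
    # Mirrors the guard added in LmDataset.__getitem__: accept either a
    # plain 1-D tensor of token ids or a ragged-like object.
    if isinstance(word_tokens, torch.Tensor):
        return word_tokens.tolist()
    return word_tokens.values.tolist()

print(to_token_ids(torch.tensor([3, 5, 7])))          # [3, 5, 7]
print(to_token_ids(Ragged(torch.tensor([3, 5, 7]))))  # [3, 5, 7]

Note that the committed code returns the tensor itself in the first branch, so `__getitem__` may yield either a tensor or a Python list; the sketch above normalizes both branches to a list purely for illustration.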


@@ -74,6 +74,44 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_lm/exp",
help="Path to save files generated during training",
)
parser.add_argument(
"--lm-dataset",
type=str,
default="data/lm_training_500/lm_data.pt",
help="LM training data. See local/prepare_lm_training_data.py",
)
parser.add_argument(
"--num-tokens",
type=int,
default="500",
help="BPE model vocab size",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
return parser
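A quick way to sanity-check the new flags is to feed parse_args an explicit argument list. This is a hypothetical snippet, not part of the commit, and it assumes get_parser above is importable from the training script:

# Hypothetical smoke test for the new CLI options (not in the commit).
parser = get_parser()
args = parser.parse_args(
    [
        "--exp-dir", "conformer_lm/exp",
        "--lm-dataset", "data/lm_training_500/lm_data.pt",
        "--num-tokens", "500",
        "--num-epochs", "30",
        "--start-epoch", "0",
    ]
)
assert args.num_tokens == 500   # type=int converts the command-line string "500"
assert args.start_epoch == 0    # 0 means train from scratch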
@@ -88,19 +126,11 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training-related
files, e.g., checkpoints, logs, etc., are saved
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- num_valid_batches: Number of batches of validation data to use each
time we compute validation loss
@@ -132,19 +162,12 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
# exp_3, vs. exp_2, is using 5e-04 not 2e-04 as max learning rate.
# exp_4, vs. exp_3, is using the Gloam optimizer with
# in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
# as well as the exponential part.
# exp_6, we change the decay from 0.85 to 0.9.
"exp_dir": Path("conformer_lm/exp_6"),
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
"num_tokens": 5000,
# "exp_dir": Path("conformer_lm/exp_6"),
# "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
# "num_tokens": 5000,
"blank_sym": 0,
"bos_sym": 1,
"eos_sym": 1,
"start_epoch": 2,
"num_epochs": 20,
"num_valid_batches": 200,
"symbols_per_batch": 5000,
"best_train_loss": float("inf"),
@@ -152,7 +175,7 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
"beam_size": 10,
@@ -607,6 +630,8 @@ def run(rank, world_size, args):
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_dataset = Path(args.lm_dataset)
world_size = args.world_size
assert world_size >= 1
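Tying the Path conversion above to the --start-epoch help text: a resumed run would look for exp_dir/epoch-{start_epoch-1}.pt. A small illustrative snippet of just that path construction (the loading call itself is not shown in this diff):

from pathlib import Path

# Illustrative: the checkpoint a resumed run would look for, per the
# --start-epoch help string; start_epoch == 0 means start from scratch.
exp_dir = Path("conformer_lm/exp")
start_epoch = 2
if start_epoch > 0:
    ckpt = exp_dir / f"epoch-{start_epoch - 1}.pt"
    print(ckpt)  # conformer_lm/exp/epoch-1.pt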