Mirror of https://github.com/k2-fsa/icefall.git
Ready to train a masked LM with vocab size 500.
commit bc122eec64
parent 7c3ab28a68
@@ -33,7 +33,13 @@ class LmDataset(torch.utils.data.Dataset):
         # self.sentences[i] returns a 1-D tensor containing word indexes
         # self.words[self.sentences[i]] returns a ragged tensor with axes
         # [word][token].
-        return self.words[self.sentences[i]].values.tolist()
+        word_tokens = self.words[self.sentences[i]]
+        # TODO(fangjun): Need to figure out why `word_tokens`
+        # can be a torch.Tensor
+        if isinstance(word_tokens, torch.Tensor):
+            return word_tokens
+        else:
+            return word_tokens.values.tolist()
 
 
 def load_train_test_lm_dataset(archive_fn: Union[str,Path],
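For context, the dataset stores each sentence as a 1-D tensor of word indexes plus a separate ragged word-to-token map, so indexing the map with a sentence yields that sentence's BPE token ids. A toy sketch of that layout using plain Python lists (the real code indexes a k2 ragged tensor, so the names and types here are illustrative only):

import torch

# words: word-id -> its BPE token ids (ragged: rows have different lengths).
words = [[5, 9], [7], [3, 3, 8]]
# sentences[i]: 1-D tensor of word indexes making up sentence i.
sentence = torch.tensor([2, 0, 1])

# Equivalent of `self.words[self.sentences[i]].values.tolist()`:
# gather each word's tokens, then flatten.
token_ids = [t for w in sentence.tolist() for t in words[w]]
print(token_ids)  # [3, 3, 8, 5, 9, 7]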
@@ -74,6 +74,44 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_lm/exp",
+        help="Path to save files generated during training",
+    )
+
+    parser.add_argument(
+        "--lm-dataset",
+        type=str,
+        default="data/lm_training_500/lm_data.pt",
+        help="LM training data. See local/prepare_lm_training_data.py",
+    )
+
+    parser.add_argument(
+        "--num-tokens",
+        type=int,
+        default="500",
+        help="BPE model vocab size",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        exp_dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
 
 
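One subtlety in the new flags: `--num-tokens` uses `default="500"` (a string) together with `type=int`. argparse runs string defaults through the `type` callable before storing them, so the stored default is still the integer 500; a plain `default=500`, as `--num-epochs` uses, would be the more conventional spelling. A standalone check:

import argparse

parser = argparse.ArgumentParser()
# String defaults are converted via `type` before being set on the Namespace.
parser.add_argument("--num-tokens", type=int, default="500")
args = parser.parse_args([])
assert args.num_tokens == 500 and isinstance(args.num_tokens, int)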
@@ -88,19 +126,11 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-    - exp_dir: It specifies the directory where all training related
-               files, e.g., checkpoints, log, etc, are saved
-
     - lr: It specifies the initial learning rate
 
     - feature_dim: The model input dim. It has to match the one used
                    in computing features.
 
-    - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-                   and continue training from that checkpoint.
-
-    - num_epochs: Number of epochs to train.
-
     - num_valid_batches: Number of batches of validation data to use each
                          time we compute validation loss
 
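The removed docstring entries now live in the `--start-epoch` and `--num-epochs` help strings added above. A minimal sketch of the resume convention they describe (a hypothetical helper; icefall's actual checkpoint loading lives elsewhere and may differ):

from pathlib import Path

import torch


def maybe_resume(exp_dir: Path, start_epoch: int, model: torch.nn.Module) -> None:
    # start_epoch == 0 means train from scratch; a positive value loads
    # the checkpoint written at the end of the previous epoch.
    if start_epoch <= 0:
        return
    ckpt = torch.load(exp_dir / f"epoch-{start_epoch - 1}.pt", map_location="cpu")
    # Assumes the state dict is stored under the "model" key, as icefall
    # checkpoints commonly do; adjust the key if the layout differs.
    model.load_state_dict(ckpt["model"])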
@@ -132,19 +162,12 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            # exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate.
-            # exp_4, vs. exp_3, is using the Gloam optimizer with
-            # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
-            # as well as the exponential part.
-            # exp_6, we change the decay from 0.85 to 0.9.
-            "exp_dir": Path("conformer_lm/exp_6"),
-            "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
-            "num_tokens": 5000,
+            # "exp_dir": Path("conformer_lm/exp_6"),
+            # "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
+            # "num_tokens": 5000,
             "blank_sym": 0,
             "bos_sym": 1,
             "eos_sym": 1,
-            "start_epoch": 2,
-            "num_epochs": 20,
             "num_valid_batches": 200,
             "symbols_per_batch": 5000,
             "best_train_loss": float("inf"),
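With `exp_dir`, `lm_dataset`, and `num_tokens` commented out here (and `start_epoch`/`num_epochs` removed), those values must now come from the parsed arguments. The usual icefall pattern, assumed here since it is not shown in this diff, is to merge the CLI args into the `AttributeDict`; a self-contained sketch with a minimal stand-in class:

from pathlib import Path


class AttributeDict(dict):
    # Minimal stand-in for icefall's AttributeDict: a dict whose keys
    # are also readable as attributes.
    def __getattr__(self, key):
        return self[key]


params = AttributeDict({"blank_sym": 0, "bos_sym": 1, "eos_sym": 1})
args = {"exp_dir": Path("conformer_lm/exp"), "num_tokens": 500}
params.update(args)  # CLI values fill in what get_params() no longer hard-codes
print(params.exp_dir, params.num_tokens)  # conformer_lm/exp 500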
@@ -152,7 +175,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 10,
+            "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
             "beam_size": 10,
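The only change in this hunk raises `log_interval` from 10 to 50, so the training loss is reported every 50 batches instead of every 10, while validation still runs every 3000 batches. A toy loop (names assumed; not the actual `run()` code) showing what the two intervals control:

log_interval = 50    # raised from 10 in this commit
valid_interval = 3000

for batch_idx in range(1, 6001):
    loss = 1.0 / batch_idx  # stand-in for the real training loss
    if batch_idx % log_interval == 0 and batch_idx <= 150:  # cap demo output
        print(f"batch {batch_idx}: loss {loss:.4f}")
    if batch_idx % valid_interval == 0:
        print(f"batch {batch_idx}: computing validation loss")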
@@ -607,6 +630,8 @@ def run(rank, world_size, args):
 def main():
     parser = get_parser()
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lm_dataset = Path(args.lm_dataset)
 
     world_size = args.world_size
     assert world_size >= 1