adapt for S, M and L training subset

luomingshuang 2022-04-27 13:43:54 +08:00
parent 0930748b61
commit 4b567e480f
3 changed files with 46 additions and 9 deletions

View File

@@ -198,6 +198,13 @@ class WenetSpeechAsrDataModule:
             help="lazily open CutSets to avoid OOM (for L|XL subset)",
         )
 
+        group.add_argument(
+            "--training-subset",
+            type=str,
+            default="L",
+            help="The training subset to use, e.g. S, M or L.",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -316,6 +323,7 @@ class WenetSpeechAsrDataModule:
 
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
+            print(sampler_state_dict.keys())
             train_sampler.load_state_dict(sampler_state_dict)
 
         # 'seed' is derived from the current random state, which will have
@@ -396,7 +404,6 @@
             world_size=1,
             shuffle=False,
         )
-
         from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
         test_iter_dataset = IterableDatasetWrapper(
@@ -417,14 +424,12 @@
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
             )
         else:
             cuts_train = CutSet.from_file(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
             )
         return cuts_train
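With this change, train_cuts() picks the manifest named by the new --training-subset flag instead of the hard-coded cuts_L.jsonl.gz. A minimal sketch of the resulting selection logic, assuming an args namespace with manifest_dir, training_subset and a lazy_load flag (the flag name here is an assumption; the recipe's actual option may differ):

import argparse
from pathlib import Path

from lhotse import CutSet


def train_cuts(args: argparse.Namespace) -> CutSet:
    # e.g. cuts_S.jsonl.gz, cuts_M.jsonl.gz or cuts_L.jsonl.gz
    manifest = Path(args.manifest_dir) / f"cuts_{args.training_subset}.jsonl.gz"
    if args.lazy_load:  # assumed flag; open the manifest lazily to avoid OOM on L/XL
        return CutSet.from_jsonl_lazy(manifest)
    return CutSet.from_file(manifest)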

View File

@@ -101,6 +101,15 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+    parser.add_argument(
+        "--batch",
+        type=int,
+        default=None,
+        help="It specifies the batch checkpoint to use for decoding, "
+        "i.e. exp_dir/checkpoint-{batch}.pt.",
+    )
     parser.add_argument(
         "--avg",
         type=int,
@@ -499,6 +508,11 @@ def main():
         model.load_state_dict(average_checkpoints(filenames, device=device))
     elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    elif params.batch is not None:
+        filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints([filenames], device=device))
     else:
         start = params.epoch - params.avg + 1
         filenames = []
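The new --batch option points decoding at a single mid-epoch checkpoint, exp_dir/checkpoint-{batch}.pt, written every --save-every-n batches, rather than averaging the last --avg epoch checkpoints. A rough sketch of the selection logic (the helper name is illustrative; in decode.py this lives inline in main()). Note that the sketch checks --batch first, whereas in the diff the branch sits below the params.avg == 1 case and is therefore only reached when --avg is not 1:

import logging

from icefall.checkpoint import average_checkpoints, load_checkpoint


def load_weights(params, model, device):
    if params.batch is not None:
        # Single mid-epoch checkpoint, e.g. exp_dir/checkpoint-8000.pt.
        filename = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
        logging.info(f"loading {filename}")
        model.to(device)
        # Averaging a one-element list simply loads that checkpoint's weights.
        model.load_state_dict(average_checkpoints([filename], device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = [f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))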

View File

@@ -251,7 +251,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=200,
+        default=8000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
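The default checkpoint interval goes from 200 to 8000 batches, so periodic checkpoints such as checkpoint-8000.pt are written far less often. The condition described in the help string amounts to the following check (helper name illustrative):

def should_save_checkpoint(batch_idx_train: int, save_every_n: int = 8000) -> bool:
    # Mirrors the help string: save whenever batch_idx_train % save_every_n == 0.
    return batch_idx_train % save_every_n == 0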
@@ -279,6 +279,26 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--valid-interval",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the valid_interval to 3000.
+        When training_subset is M, set the valid_interval to 1000.
+        When training_subset is S, set the valid_interval to 400.
+        """,
+    )
+
+    parser.add_argument(
+        "--model-warm-step",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the model_warm_step to 3000.
+        When training_subset is M, set the model_warm_step to 500.
+        When training_subset is S, set the model_warm_step to 100.
+        """,
+    )
+
     return parser
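valid_interval and model_warm_step become command-line options whose recommended values depend on the training subset, per the help strings above. A small sketch collecting those recommendations (the helper is not part of the commit; the numbers come straight from the help strings):

# Recommended values from the --valid-interval / --model-warm-step help strings.
SUBSET_SETTINGS = {
    "L": {"valid_interval": 3000, "model_warm_step": 3000},
    "M": {"valid_interval": 1000, "model_warm_step": 500},
    "S": {"valid_interval": 400, "model_warm_step": 100},
}


def recommended_settings(training_subset: str) -> dict:
    # e.g. recommended_settings("M") -> {"valid_interval": 1000, "model_warm_step": 500}
    return SUBSET_SETTINGS[training_subset]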
@@ -333,9 +353,8 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
+            "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -348,7 +367,6 @@ def get_params() -> AttributeDict:
             # parameters for joiner
             "joiner_dim": 512,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
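With log_interval raised to 50 and valid_interval / model_warm_step dropped from get_params(), those two values now have to come from the parsed arguments. A sketch of the usual icefall pattern of folding args into params, assuming this recipe follows it (the wrapper name is illustrative; get_params() is the function edited above):

from icefall.utils import AttributeDict


def make_params(args) -> AttributeDict:
    # get_params() no longer contains valid_interval or model_warm_step,
    # so both arrive via the new command-line options.
    params = get_params()
    params.update(vars(args))  # brings in args.valid_interval, args.model_warm_step, ...
    return params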