add sampler state_dict

This commit is contained in:
Yuekai Zhang 2024-06-14 09:41:08 +08:00
parent 7db5445d1e
commit 618b686166
2 changed files with 11 additions and 1 deletion

View File

@ -588,7 +588,7 @@ def main():
test_sets_cuts = multi_dataset.aishell_test_cuts()
elif params.dataset == "speechio":
test_sets_cuts = multi_dataset.speechio_test_cuts()
elif params.dataaset == "wenetspeech_test_meeting":
elif params.dataset == "wenetspeech_test_meeting":
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
else:
test_sets_cuts = multi_dataset.test_cuts()

View File

@ -190,6 +190,14 @@ def get_parser():
""",
)
parser.add_argument(
"--sampler-state-dict-path",
type=str,
default=None,
help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
""",
)
parser.add_argument(
"--base-lr", type=float, default=1e-5, help="The base learning rate."
)
@ -813,6 +821,8 @@ def run(rank, world_size, args):
# else:
# sampler_state_dict = None
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict