fix max_duration: override the max_duration restored from the sampler state dict with the current params.max_duration

This commit is contained in:
Yuekai Zhang 2024-06-14 09:43:52 +08:00
parent 618b686166
commit d1e31c7ac7

View File

@ -823,6 +823,7 @@ def run(rank, world_size, args):
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
# NOTE: sampler state dict is loaded above; its max_duration is
# overridden so resuming uses the currently requested batch duration
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict