mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add sampler state_dict
This commit is contained in:
parent
7db5445d1e
commit
618b686166
@ -588,7 +588,7 @@ def main():
|
|||||||
test_sets_cuts = multi_dataset.aishell_test_cuts()
|
test_sets_cuts = multi_dataset.aishell_test_cuts()
|
||||||
elif params.dataset == "speechio":
|
elif params.dataset == "speechio":
|
||||||
test_sets_cuts = multi_dataset.speechio_test_cuts()
|
test_sets_cuts = multi_dataset.speechio_test_cuts()
|
||||||
elif params.dataaset == "wenetspeech_test_meeting":
|
elif params.dataset == "wenetspeech_test_meeting":
|
||||||
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
|
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
|
||||||
else:
|
else:
|
||||||
test_sets_cuts = multi_dataset.test_cuts()
|
test_sets_cuts = multi_dataset.test_cuts()
|
||||||
|
@ -190,6 +190,14 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampler-state-dict-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
||||||
)
|
)
|
||||||
@ -813,6 +821,8 @@ def run(rank, world_size, args):
|
|||||||
# else:
|
# else:
|
||||||
# sampler_state_dict = None
|
# sampler_state_dict = None
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
if params.sampler_state_dict_path:
|
||||||
|
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||||
# TODO: load sampler state dict
|
# TODO: load sampler state dict
|
||||||
train_dl = data_module.train_dataloaders(
|
train_dl = data_module.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
|
Loading…
x
Reference in New Issue
Block a user