mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Reset seed at the beginning of each epoch. (#221)
* Reset seed at the beginning of each epoch. * Use a different seed for each epoch.
This commit is contained in:
parent
cbf8c18ebd
commit
1c35ae1dba
@ -121,6 +121,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -555,7 +562,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -618,6 +625,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -124,6 +124,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -546,7 +553,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -613,6 +620,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -92,6 +92,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -507,7 +514,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -557,6 +564,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if epoch > params.start_epoch:
|
if epoch > params.start_epoch:
|
||||||
|
@ -129,6 +129,13 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -534,7 +541,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -611,6 +618,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -140,6 +140,13 @@ def get_parser():
|
|||||||
help="The lr_factor for Noam optimizer",
|
help="The lr_factor for Noam optimizer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -580,7 +587,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -684,6 +691,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -109,6 +109,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -673,7 +680,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -761,6 +768,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders()
|
valid_dl = librispeech.valid_dataloaders()
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
if (
|
if (
|
||||||
params.batch_idx_train >= params.use_ali_until
|
params.batch_idx_train >= params.use_ali_until
|
||||||
|
@ -179,6 +179,13 @@ def get_parser():
|
|||||||
"with this parameter before adding to the final loss.",
|
"with this parameter before adding to the final loss.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -642,7 +649,7 @@ def run(rank, world_size, args):
|
|||||||
params.valid_interval = 800
|
params.valid_interval = 800
|
||||||
params.warm_step = 30000
|
params.warm_step = 30000
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -731,6 +738,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -138,6 +138,13 @@ def get_parser():
|
|||||||
help="Proportion of samples trained with short right context",
|
help="Proportion of samples trained with short right context",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -575,7 +582,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -645,6 +652,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -95,6 +95,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -486,7 +493,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -544,6 +551,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if epoch > params.start_epoch:
|
if epoch > params.start_epoch:
|
||||||
|
@ -130,6 +130,13 @@ def get_parser():
|
|||||||
help="The lr_factor for Noam optimizer",
|
help="The lr_factor for Noam optimizer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -544,7 +551,7 @@ def run(rank, world_size, args):
|
|||||||
params.valid_interval = 800
|
params.valid_interval = 800
|
||||||
params.warm_step = 8000
|
params.warm_step = 8000
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -633,6 +640,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -131,6 +131,13 @@ def get_parser():
|
|||||||
help="The lr_factor for Noam optimizer",
|
help="The lr_factor for Noam optimizer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -548,7 +555,7 @@ def run(rank, world_size, args):
|
|||||||
params.valid_interval = 800
|
params.valid_interval = 800
|
||||||
params.warm_step = 8000
|
params.warm_step = 8000
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -639,6 +646,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -149,6 +149,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -562,7 +569,7 @@ def run(rank, world_size, args):
|
|||||||
params.valid_interval = 800
|
params.valid_interval = 800
|
||||||
params.warm_step = 8000
|
params.warm_step = 8000
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -651,6 +658,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
|
@ -95,6 +95,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -486,7 +493,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -536,6 +543,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = timit.valid_dataloaders()
|
valid_dl = timit.valid_dataloaders()
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if epoch > params.start_epoch:
|
if epoch > params.start_epoch:
|
||||||
|
@ -95,6 +95,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -486,7 +493,7 @@ def run(rank, world_size, args):
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -536,6 +543,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = timit.valid_dataloaders()
|
valid_dl = timit.valid_dataloaders()
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if epoch > params.start_epoch:
|
if epoch > params.start_epoch:
|
||||||
|
@ -71,6 +71,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -468,7 +475,7 @@ def run(rank, world_size, args):
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
params["env_info"] = get_env_info()
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -520,6 +527,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = yes_no.test_dataloaders()
|
valid_dl = yes_no.test_dataloaders()
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
|
@ -114,6 +114,13 @@ def get_parser():
|
|||||||
help="Directory to save results",
|
help="Directory to save results",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -487,7 +494,7 @@ def run(rank, world_size, args):
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
params["env_info"] = get_env_info()
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
@ -532,6 +539,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = yes_no.test_dataloaders()
|
valid_dl = yes_no.test_dataloaders()
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user