Reset seed at the beginning of each epoch. (#221)

* Reset seed at the beginning of each epoch.

* Use a different seed for each epoch.
This commit is contained in:
Fangjun Kuang 2022-02-21 15:16:39 +08:00 committed by GitHub
parent cbf8c18ebd
commit 1c35ae1dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 144 additions and 16 deletions

View File

@ -121,6 +121,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -555,7 +562,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -618,6 +625,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -124,6 +124,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -546,7 +553,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -613,6 +620,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -92,6 +92,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -507,7 +514,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -557,6 +564,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:

View File

@ -129,6 +129,13 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -534,7 +541,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -611,6 +618,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -140,6 +140,13 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -580,7 +587,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -684,6 +691,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -109,6 +109,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -673,7 +680,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -761,6 +768,7 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if (
params.batch_idx_train >= params.use_ali_until

View File

@ -179,6 +179,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -642,7 +649,7 @@ def run(rank, world_size, args):
params.valid_interval = 800
params.warm_step = 30000
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -731,6 +738,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -138,6 +138,13 @@ def get_parser():
help="Proportion of samples trained with short right context",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -575,7 +582,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -645,6 +652,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -95,6 +95,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -486,7 +493,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -544,6 +551,7 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:

View File

@ -130,6 +130,13 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -544,7 +551,7 @@ def run(rank, world_size, args):
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -633,6 +640,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -131,6 +131,13 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -548,7 +555,7 @@ def run(rank, world_size, args):
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -639,6 +646,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -149,6 +149,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -562,7 +569,7 @@ def run(rank, world_size, args):
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -651,6 +658,7 @@ def run(rank, world_size, args):
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate

View File

@ -95,6 +95,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -486,7 +493,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -536,6 +543,7 @@ def run(rank, world_size, args):
valid_dl = timit.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:

View File

@ -95,6 +95,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -486,7 +493,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -536,6 +543,7 @@ def run(rank, world_size, args):
valid_dl = timit.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:

View File

@ -71,6 +71,13 @@ def get_parser():
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -468,7 +475,7 @@ def run(rank, world_size, args):
params.update(vars(args))
params["env_info"] = get_env_info()
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -520,6 +527,7 @@ def run(rank, world_size, args):
valid_dl = yes_no.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:

View File

@ -114,6 +114,13 @@ def get_parser():
help="Directory to save results",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
@ -487,7 +494,7 @@ def run(rank, world_size, args):
params.update(vars(args))
params["env_info"] = get_env_info()
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -532,6 +539,7 @@ def run(rank, world_size, args):
valid_dl = yes_no.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None: