Use a different seed for each epoch.

Fangjun Kuang 2022-02-21 14:27:26 +08:00
parent 791f54c8c2
commit 407e8aeff7
16 changed files with 16 additions and 16 deletions
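
The pattern is identical in every recipe touched by this commit: re-seeding with params.seed alone at the top of the epoch loop made every epoch replay the same random stream (same shuffling order, same augmentation), whereas params.seed + epoch keeps runs reproducible while giving each epoch its own stream. A minimal sketch of the idea follows; fix_random_seed here is an illustrative stand-in for the helper called in the recipes (the real one may seed additional RNGs), and the seed and loop bounds are hypothetical values:

import random

import numpy as np
import torch


def fix_random_seed(seed: int) -> None:
    # Illustrative stand-in for the helper used in the diffs below;
    # seed the RNGs that drive shuffling and augmentation.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


seed = 42        # plays the role of params.seed
start_epoch = 0  # hypothetical values for this sketch
num_epochs = 3

for epoch in range(start_epoch, num_epochs):
    # Before this commit: fix_random_seed(seed) replays the identical
    # random stream every epoch. Offsetting by the epoch index gives
    # each epoch a distinct but still reproducible stream.
    fix_random_seed(seed + epoch)
    # train_dl.sampler.set_epoch(epoch)  # as in the recipes; omitted here
    print(epoch, [random.random() for _ in range(3)])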


@@ -625,7 +625,7 @@ def run(rank, world_size, args):
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -620,7 +620,7 @@ def run(rank, world_size, args):
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -564,7 +564,7 @@ def run(rank, world_size, args):
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if epoch > params.start_epoch:


@@ -618,7 +618,7 @@ def run(rank, world_size, args):
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -691,7 +691,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -768,7 +768,7 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders()
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if (
             params.batch_idx_train >= params.use_ali_until


@@ -738,7 +738,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -652,7 +652,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -551,7 +551,7 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if epoch > params.start_epoch:


@@ -640,7 +640,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -646,7 +646,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -658,7 +658,7 @@ def run(rank, world_size, args):
     )
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate


@@ -543,7 +543,7 @@ def run(rank, world_size, args):
     valid_dl = timit.valid_dataloaders()
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if epoch > params.start_epoch:


@@ -543,7 +543,7 @@ def run(rank, world_size, args):
     valid_dl = timit.valid_dataloaders()
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if epoch > params.start_epoch:


@@ -527,7 +527,7 @@ def run(rank, world_size, args):
     valid_dl = yes_no.test_dataloaders()
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if tb_writer is not None:


@@ -539,7 +539,7 @@ def run(rank, world_size, args):
     valid_dl = yes_no.test_dataloaders()
     for epoch in range(params.start_epoch, params.num_epochs):
-        fix_random_seed(params.seed)
+        fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
         if tb_writer is not None: