mOve where srand called

This commit is contained in:
Daniel Povey 2023-05-19 16:43:21 +08:00
parent f37ec0f0da
commit 7d162bf41e

View File

@ -907,7 +907,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(params.seed + params.start_batch) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -1007,7 +1007,7 @@ def run(rank, world_size, args):
# to let it know how many tokens we have processed so far, and have a # to let it know how many tokens we have processed so far, and have a
# soft-cutoff lr_tokens measured in tokens. # soft-cutoff lr_tokens measured in tokens.
# scheduler.step_epoch(epoch - 1) # scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1) fix_random_seed(params.seed + epoch - 1 + params.start_batch)
# the above will affect random seeds in the dataloaders. # the above will affect random seeds in the dataloaders.
if tb_writer is not None: if tb_writer is not None: