diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py index b1c625df4..ab80d93a7 100755 --- a/egs/libriheavy/LM/zipformer1/train.py +++ b/egs/libriheavy/LM/zipformer1/train.py @@ -907,7 +907,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(params.seed + params.start_batch) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -1007,7 +1007,7 @@ def run(rank, world_size, args): # to let it know how many tokens we have processed so far, and have a # soft-cutoff lr_tokens measured in tokens. # scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) + fix_random_seed(params.seed + epoch - 1 + params.start_batch) # the above will affect random seeds in the dataloaders. if tb_writer is not None: