fix for multigpu

Author: yfyeung
Date:   2025-06-18 07:33:15 +00:00
Parent: 39d90356fe
Commit: 53111d0e46


@@ -693,6 +693,9 @@ def train_one_epoch(
                 exclude_frozen_parameters=True,
             )
+            if world_size > 1:
+                torch.distributed.barrier()
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
                     params.exp_dir,
@@ -710,6 +713,9 @@ def train_one_epoch(
                     f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                 )
+            if world_size > 1:
+                torch.distributed.barrier()
             shave_rate = params.shave_rate
             while True:
                 try:
@@ -991,6 +997,10 @@ def run(rank, world_size, args):
             client_state={},
             exclude_frozen_parameters=True,
         )
+        if world_size > 1:
+            torch.distributed.barrier()
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
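
All three hunks apply the same pattern: after every rank has written its ZeRO shard via DeepSpeed's save_checkpoint, the processes synchronize with torch.distributed.barrier() before rank 0 consolidates (and later deletes) the sharded checkpoint, so rank 0 never reads or removes files another rank is still writing. The sketch below illustrates that pattern only; the helper name save_and_consolidate, its arguments, and the output path are made up for illustration, the process group is assumed to be initialized already, and the exact signature of convert_zero_checkpoint_to_fp32_state_dict can differ between DeepSpeed versions.

import torch.distributed as dist
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict


def save_and_consolidate(model, exp_dir, tag, rank, world_size):
    """Illustrative helper, not the repository's code."""
    # Every rank writes its own ZeRO shard of the checkpoint.
    model.save_checkpoint(
        exp_dir,
        tag=tag,
        client_state={},
        exclude_frozen_parameters=True,
    )

    # Without this barrier, rank 0 could start consolidating (and afterwards
    # deleting) the sharded checkpoint before the other ranks finish writing.
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        # Merge the per-rank ZeRO shards into a single fp32 checkpoint.
        convert_zero_checkpoint_to_fp32_state_dict(
            exp_dir,                # directory containing the ZeRO shards
            f"{exp_dir}/{tag}.pt",  # consolidated output (path is illustrative)
            tag=tag,
        )

    # Keep the non-zero ranks from racing ahead while rank 0 is still busy,
    # mirroring the barrier the commit adds after the cleanup step.
    if world_size > 1:
        dist.barrier()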