fix for multigpu

yfyeung 2025-06-18 07:33:15 +00:00
parent 39d90356fe
commit 53111d0e46


@@ -693,6 +693,9 @@ def train_one_epoch(
                 exclude_frozen_parameters=True,
             )
+            if world_size > 1:
+                torch.distributed.barrier()
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
                     params.exp_dir,
@@ -710,6 +713,9 @@ def train_one_epoch(
                     f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                 )
+            if world_size > 1:
+                torch.distributed.barrier()
         shave_rate = params.shave_rate
         while True:
             try:
@@ -991,6 +997,10 @@ def run(rank, world_size, args):
             client_state={},
             exclude_frozen_parameters=True,
         )
+        if world_size > 1:
+            torch.distributed.barrier()
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
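
The change inserts torch.distributed.barrier() calls around the rank-0 checkpoint conversion: one barrier so that rank 0 only starts merging ZeRO shards after every rank has finished writing its own, and one so that the other ranks do not run ahead while rank 0 is still converting and cleaning up files. Below is a minimal sketch of that pattern, assuming a DeepSpeed setup like the one in the diff; the save_and_convert function and its arguments are illustrative, not the exact code of the training script.

# Sketch of the synchronization pattern this commit adds. Assumes a DeepSpeed
# ZeRO setup where every rank writes its own checkpoint shard and only rank 0
# merges them into a single fp32 state dict. Names such as exp_dir and tag
# mirror the diff; the wrapper function itself is hypothetical.
import os

import torch.distributed as dist
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict


def save_and_convert(model, exp_dir: str, tag: str, rank: int, world_size: int) -> None:
    # Collective save: every rank writes its ZeRO shard under exp_dir/tag.
    model.save_checkpoint(
        save_dir=exp_dir,
        tag=tag,
        client_state={},
        exclude_frozen_parameters=True,
    )

    # First barrier: wait until all ranks have finished writing their shards
    # before rank 0 starts reading them.
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        # Merge the shards into one fp32 checkpoint, then drop the shard dir.
        convert_zero_checkpoint_to_fp32_state_dict(
            exp_dir,
            os.path.join(exp_dir, f"{tag}.pt"),
            tag=tag,
        )
        os.system(f"rm -rf {os.path.join(exp_dir, tag)}")

    # Second barrier: keep the other ranks from racing into the next step
    # while rank 0 is still converting and deleting files.
    if world_size > 1:
        dist.barrier()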