mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
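Fix for multi-GPU training: synchronize ranks with torch.distributed.barrier() around the rank-0 checkpoint conversion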
This commit is contained in:
parent
39d90356fe
commit
53111d0e46
@@ -693,6 +693,9 @@ def train_one_epoch(
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
@@ -710,6 +713,9 @@ def train_one_epoch(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
|
||||
shave_rate = params.shave_rate
|
||||
while True:
|
||||
try:
|
||||
@@ -991,6 +997,10 @@ def run(rank, world_size, args):
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
|
Loading…
x
Reference in New Issue
Block a user