mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
fix for multigpu
This commit is contained in:
parent
39d90356fe
commit
53111d0e46
@ -693,6 +693,9 @@ def train_one_epoch(
|
|||||||
exclude_frozen_parameters=True,
|
exclude_frozen_parameters=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
params.exp_dir,
|
params.exp_dir,
|
||||||
@ -710,6 +713,9 @@ def train_one_epoch(
|
|||||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
shave_rate = params.shave_rate
|
shave_rate = params.shave_rate
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -991,6 +997,10 @@ def run(rank, world_size, args):
|
|||||||
client_state={},
|
client_state={},
|
||||||
exclude_frozen_parameters=True,
|
exclude_frozen_parameters=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
params.exp_dir,
|
params.exp_dir,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user