Fix computing averaged loss in the aishell recipe. (#523)
* Fix computing averaged loss in the aishell recipe.
* Set find_unused_parameters optionally.
parent f24b76e64b
commit 5149788cb2
@@ -22,8 +22,12 @@
 Usage:
 
 ./prepare.sh
 
+# If you use a non-zero value for --datatang-prob, you also need to run
 ./prepare_aidatatang_200zh.sh
 
+If you use --datatang-prob=0, then you don't need to run the above script.
+
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 
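Since a non-zero --datatang-prob silently depends on prepare_aidatatang_200zh.sh having been run, a startup guard can fail fast instead of crashing mid-training. A minimal sketch, assuming a hypothetical manifest directory layout (check_datatang_prepared and the *.jsonl.gz pattern are illustrative, not part of the recipe):

from pathlib import Path

def check_datatang_prepared(datatang_prob: float, manifest_dir: str) -> None:
    # Fail fast if --datatang-prob > 0 but prepare_aidatatang_200zh.sh
    # has not produced any manifests yet. The directory and file pattern
    # are assumptions for illustration, not the recipe's actual layout.
    if datatang_prob > 0 and not any(Path(manifest_dir).glob("*.jsonl.gz")):
        raise RuntimeError(
            f"--datatang-prob={datatang_prob} requires aidatatang_200zh "
            f"manifests in {manifest_dir}; run ./prepare_aidatatang_200zh.sh first."
        )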
@@ -62,7 +66,6 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,7 +347,7 @@ def get_parser():
     parser.add_argument(
         "--datatang-prob",
         type=float,
-        default=0.2,
+        default=0.0,
         help="""The probability to select a batch from the
         aidatatang_200zh dataset.
         If it is set to 0, you don't need to download the data
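For context, --datatang-prob is the per-batch probability of drawing from aidatatang_200zh instead of aishell. A minimal sketch of that selection, assuming plain Python iterators (mixed_batches and its return shape are illustrative; the recipe's real interleaving logic differs):

import random

def mixed_batches(aishell_iter, datatang_iter, datatang_prob: float):
    # Yield (batch, source) pairs, picking aidatatang_200zh with
    # probability datatang_prob and aishell otherwise; stop when the
    # chosen stream is exhausted.
    while True:
        use_datatang = (
            datatang_iter is not None and random.random() < datatang_prob
        )
        try:
            if use_datatang:
                yield next(datatang_iter), "datatang"
            else:
                yield next(aishell_iter), "aishell"
        except StopIteration:
            return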
@@ -945,7 +948,10 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )
 
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if datatang_train_dl is not None:
+        loss_value = tot_loss["loss"] / tot_loss["frames"]
+    else:
+        loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
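The fix matters because the epoch loss is frame-weighted: when datatang is disabled, only the aishell tracker accumulates frames, so the combined tracker would give a wrong (or zero-division) average. A small sketch of the accumulation pattern, using a plain dict in place of icefall's MetricsTracker (whose real API is richer):

def update_tracker(tracker: dict, batch_loss: float, num_frames: int) -> None:
    # Accumulate total loss and total frames so the epoch average is
    # sum(loss) / sum(frames), not a mean of per-batch means.
    tracker["loss"] = tracker.get("loss", 0.0) + batch_loss
    tracker["frames"] = tracker.get("frames", 0) + num_frames

aishell_tot_loss = {}
update_tracker(aishell_tot_loss, batch_loss=123.4, num_frames=2000)
update_tracker(aishell_tot_loss, batch_loss=98.7, num_frames=1500)
loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]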
@@ -1032,7 +1038,16 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        if params.datatang_prob > 0:
+            find_unused_parameters = True
+        else:
+            find_unused_parameters = False
+
+        model = DDP(
+            model,
+            device_ids=[rank],
+            find_unused_parameters=find_unused_parameters,
+        )
 
     optimizer = Eve(model.parameters(), lr=params.initial_lr)
 
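find_unused_parameters=True makes DDP scan the autograd graph every iteration for parameters that received no gradient, which adds overhead, so it is only worth enabling when mixed-dataset batches can leave part of the model unused. A minimal sketch of the same conditional wrapping, assuming torch.distributed is already initialized (wrap_model is a hypothetical helper, not part of the recipe):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model: nn.Module, rank: int, datatang_prob: float) -> DDP:
    # Pay for the unused-parameter scan only when batches from a second
    # dataset may leave some parameters without gradients.
    return DDP(
        model.to(rank),
        device_ids=[rank],
        find_unused_parameters=datatang_prob > 0,
    )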