Fix computing averaged loss in the aishell recipe.

This commit is contained in:
Fangjun Kuang 2022-08-06 12:25:11 +08:00
parent 1f7832b93c
commit 540020bdc8

View File

@@ -22,8 +22,12 @@
 Usage:
 ./prepare.sh
+# If you use a non-zero value for --datatang-prob, you also need to run
 ./prepare_aidatatang_200zh.sh
+If you use --datatang-prob=0, then you don't need to run the above script.
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
@@ -62,7 +66,6 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,7 +347,7 @@ def get_parser():
 parser.add_argument(
 "--datatang-prob",
 type=float,
-default=0.2,
+default=0.0,
 help="""The probability to select a batch from the
 aidatatang_200zh dataset.
 If it is set to 0, you don't need to download the data
@@ -945,7 +948,10 @@ def train_one_epoch(
 tb_writer, "train/valid_", params.batch_idx_train
 )
-loss_value = tot_loss["loss"] / tot_loss["frames"]
+if datatang_train_dl is not None:
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+else:
+    loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
 params.train_loss = loss_value
 if params.train_loss < params.best_train_loss:
 params.best_train_epoch = params.cur_epoch