Use aidatatang_200zh optionally in aishell training.

Fangjun Kuang 2022-07-26 10:55:56 +08:00
parent 4612b03947
commit b34eafa500


@@ -62,6 +62,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,8 +345,11 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="""The probability to select a batch from the
+        aidatatang_200zh dataset.
+        If it is set to 0, you don't need to download the data
+        for aidatatang_200zh.
+        """,
     )
 
     add_model_arguments(parser)
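
With the default of 0.2, roughly one training batch in five is drawn from aidatatang_200zh; passing --datatang-prob 0 on the command line trains on aishell alone and, as the new help text says, removes the need to download or prepare the aidatatang_200zh data at all.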
@@ -778,13 +782,17 @@ def train_one_epoch(
     dl_weights = [1 - params.datatang_prob, params.datatang_prob]
 
     iter_aishell = iter(train_dl)
-    iter_datatang = iter(datatang_train_dl)
+    if datatang_train_dl is not None:
+        iter_datatang = iter(datatang_train_dl)
 
     batch_idx = 0
 
     while True:
-        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
-        dl = iter_aishell if idx == 0 else iter_datatang
+        if datatang_train_dl is not None:
+            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+            dl = iter_aishell if idx == 0 else iter_datatang
+        else:
+            dl = iter_aishell
 
         try:
             batch = next(dl)
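
As a standalone illustration of the batch-mixing logic in the hunk above, here is a minimal sketch. The two iterators and the toy batch strings are placeholders; in the real script they are the aishell and aidatatang_200zh dataloaders, and rng / dl_weights come from the surrounding training loop.

    import random

    # Stand-ins for the two dataloaders; set datatang_batches = None to mimic --datatang-prob 0.
    aishell_batches = iter([f"aishell-{i}" for i in range(5)])
    datatang_batches = iter([f"datatang-{i}" for i in range(5)])

    datatang_prob = 0.2
    dl_weights = [1 - datatang_prob, datatang_prob]
    rng = random.Random(42)

    while True:
        if datatang_batches is not None:
            # Index 0 selects aishell, index 1 selects aidatatang_200zh.
            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
            dl = aishell_batches if idx == 0 else datatang_batches
        else:
            dl = aishell_batches
        try:
            batch = next(dl)
        except StopIteration:
            # The real script repeats the aidatatang_200zh cuts indefinitely,
            # so exhaustion normally means the aishell epoch has finished.
            break
        print(batch)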
@@ -1032,11 +1040,6 @@ def run(rank, world_size, args):
     train_cuts = aishell.train_cuts()
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
-    datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
-    train_datatang_cuts = datatang.train_cuts()
-    train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
-    train_datatang_cuts = train_datatang_cuts.repeat(times=None)
-
     if args.enable_musan:
         cuts_musan = load_manifest(
             Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
@@ -1052,11 +1055,21 @@ def run(rank, world_size, args):
         cuts_musan=cuts_musan,
     )
 
-    datatang_train_dl = asr_datamodule.train_dataloaders(
-        train_datatang_cuts,
-        on_the_fly_feats=False,
-        cuts_musan=cuts_musan,
-    )
+    if params.datatang_prob > 0:
+        datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+        train_datatang_cuts = datatang.train_cuts()
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
+        train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+        datatang_train_dl = asr_datamodule.train_dataloaders(
+            train_datatang_cuts,
+            on_the_fly_feats=False,
+            cuts_musan=cuts_musan,
+        )
+    else:
+        datatang_train_dl = None
+        logging.info("Not using aidatatang_200zh for training")
 
     valid_cuts = aishell.valid_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
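
The same guard in isolation: the auxiliary dataset is only loaded when the mixing probability is positive, and its cuts are repeated without bound (the train_datatang_cuts.repeat(times=None) call above) so the auxiliary iterator never runs dry mid-epoch. A rough sketch, using itertools.cycle in place of lhotse's lazy repeat and a plain list in place of the real dataloader:

    import itertools
    import logging

    def make_optional_datatang_loader(datatang_prob, datatang_batches):
        """Return an endless iterator over the auxiliary batches, or None when disabled."""
        if datatang_prob > 0:
            # itertools.cycle stands in for CutSet.repeat(times=None) in this sketch.
            return itertools.cycle(datatang_batches)
        logging.info("Not using aidatatang_200zh for training")
        return None

    loader = make_optional_datatang_loader(0.2, ["datatang-0", "datatang-1"])
    print(next(loader), next(loader), next(loader))  # wraps around: datatang-0 datatang-1 datatang-0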
@@ -1065,13 +1078,14 @@ def run(rank, world_size, args):
             train_dl,
             # datatang_train_dl
         ]:
-            scan_pessimistic_batches_for_oom(
-                model=model,
-                train_dl=dl,
-                optimizer=optimizer,
-                graph_compiler=graph_compiler,
-                params=params,
-            )
+            if dl is not None:
+                scan_pessimistic_batches_for_oom(
+                    model=model,
+                    train_dl=dl,
+                    optimizer=optimizer,
+                    graph_compiler=graph_compiler,
+                    params=params,
+                )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1083,7 +1097,8 @@ def run(rank, world_size, args):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
-        datatang_train_dl.sampler.set_epoch(epoch)
+        if datatang_train_dl is not None:
+            datatang_train_dl.sampler.set_epoch(epoch)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
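
The is-not-None guard in this last hunk keeps the epoch loop from dereferencing a missing dataloader when aidatatang_200zh is disabled, while calling set_epoch on the sampler keeps the shuffling order reproducible yet different from one epoch to the next.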