Support using aidatatang_200zh optionally in aishell training (#495)

* Use aidatatang_200zh optionally in aishell training.
Fangjun Kuang 2022-07-26 11:25:01 +08:00 committed by GitHub
parent 4612b03947
commit d3fc4b031e


@@ -62,6 +62,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,8 +345,11 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="""The probability to select a batch from the
+        aidatatang_200zh dataset.
+        If it is set to 0, you don't need to download the data
+        for aidatatang_200zh.
+        """,
     )

     add_model_arguments(parser)
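Note: a minimal, self-contained sketch of how a mixing probability like --datatang-prob is typically consumed; only the flag name, type and default are taken from the commit, the rest of the demo (the fake command line, the use_datatang variable) is illustrative.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--datatang-prob",
        type=float,
        default=0.2,
        help="Probability of drawing a batch from aidatatang_200zh; "
        "0 disables the extra dataset entirely.",
    )
    # Parse a fake command line for the demo.
    args = parser.parse_args(["--datatang-prob", "0"])

    # The value doubles as a sampling weight and as an on/off switch,
    # which is how the rest of this commit treats it.
    use_datatang = args.datatang_prob > 0
    print(use_datatang)  # False -> aidatatang_200zh is never loaded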
@@ -457,8 +461,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    decoder_datatang = get_decoder_model(params)
-    joiner_datatang = get_joiner_model(params)
+    if params.datatang_prob > 0:
+        decoder_datatang = get_decoder_model(params)
+        joiner_datatang = get_joiner_model(params)
+    else:
+        decoder_datatang = None
+        joiner_datatang = None

     model = Transducer(
         encoder=encoder,
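The point of this hunk is that the second decoder/joiner pair only exists when datatang batches can actually be drawn, and None is passed through otherwise. A rough sketch of that gating pattern with placeholder nn modules; the real get_*_model builders and the Transducer class are not used here.

    import torch.nn as nn

    def build_model(datatang_prob: float) -> nn.Module:
        # Placeholder components standing in for the encoder/decoder/joiner
        # builders; only the gating logic mirrors the commit.
        encoder = nn.Linear(80, 512)
        decoder = nn.Embedding(500, 512)
        joiner = nn.Linear(512, 500)

        if datatang_prob > 0:
            # Separate decoder/joiner branch for the auxiliary dataset.
            decoder_datatang = nn.Embedding(500, 512)
            joiner_datatang = nn.Linear(512, 500)
        else:
            # No datatang batches -> no extra parameters to create or train.
            decoder_datatang = None
            joiner_datatang = None

        modules = nn.ModuleDict(
            {"encoder": encoder, "decoder": decoder, "joiner": joiner}
        )
        if decoder_datatang is not None:
            modules["decoder_datatang"] = decoder_datatang
            modules["joiner_datatang"] = joiner_datatang
        return modules

    print(sorted(build_model(0.0).keys()))  # ['decoder', 'encoder', 'joiner']
    print(sorted(build_model(0.2).keys()))  # adds decoder_datatang / joiner_datatang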
@@ -726,7 +734,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
-    datatang_train_dl: torch.utils.data.DataLoader,
+    datatang_train_dl: Optional[torch.utils.data.DataLoader],
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: GradScaler,
@@ -778,13 +786,17 @@ def train_one_epoch(
     dl_weights = [1 - params.datatang_prob, params.datatang_prob]

     iter_aishell = iter(train_dl)
-    iter_datatang = iter(datatang_train_dl)
+    if datatang_train_dl is not None:
+        iter_datatang = iter(datatang_train_dl)

     batch_idx = 0

     while True:
-        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
-        dl = iter_aishell if idx == 0 else iter_datatang
+        if datatang_train_dl is not None:
+            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+            dl = iter_aishell if idx == 0 else iter_datatang
+        else:
+            dl = iter_aishell

         try:
             batch = next(dl)
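For reference, the loop above draws aishell with probability 1 - datatang_prob and aidatatang_200zh with probability datatang_prob on every iteration. A self-contained sketch of that weighted mixing, with plain Python iterators standing in for the two dataloaders:

    import random

    rng = random.Random(42)
    datatang_prob = 0.2

    # Toy stand-ins for the two training dataloaders.
    iter_aishell = iter(f"aishell-{i}" for i in range(1000))
    iter_datatang = iter(f"datatang-{i}" for i in range(1000))

    dl_weights = [1 - datatang_prob, datatang_prob]

    picked = []
    for _ in range(10):
        # Same weighted coin flip as in train_one_epoch: index 0 -> aishell,
        # index 1 -> aidatatang_200zh.
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        dl = iter_aishell if idx == 0 else iter_datatang
        picked.append(next(dl))

    print(picked)  # roughly 80% aishell batches, 20% datatang batches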
@@ -808,7 +820,11 @@ def train_one_epoch(
                 warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        if datatang_train_dl is not None:
+            tot_loss = (
+                tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
         if aishell:
             aishell_tot_loss = (
                 aishell_tot_loss * (1 - 1 / params.reset_interval)
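The update that is now skipped when datatang is disabled is an exponentially decayed running sum: with reset_interval = r, the accumulated statistics are scaled by (1 - 1/r) before the current batch's loss_info is added, so each batch's contribution fades over roughly r batches. loss_info is a MetricsTracker in icefall; the sketch below uses a plain float and an assumed reset_interval of 200 just to show the arithmetic:

    reset_interval = 200  # assumed value for the demo
    tot_loss = 0.0
    decay = 1 - 1 / reset_interval

    # Feed a constant per-batch loss of 1.0: the running sum converges to
    # roughly reset_interval, i.e. "average loss * reset_interval".
    for step in range(2000):
        loss_info = 1.0
        tot_loss = tot_loss * decay + loss_info

    print(round(tot_loss, 1))  # ~200.0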
@@ -871,12 +887,21 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
+            if datatang_train_dl is not None:
+                datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
+            else:
+                tot_loss_str = ""
+                datatang_str = ""

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"{tot_loss_str}"
                 f"aishell_tot_loss[{aishell_tot_loss}], "
-                f"datatang_tot_loss[{datatang_tot_loss}], "
+                f"{datatang_str}"
                 f"batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}"
             )
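After this change the periodic log line comes in two shapes, with the tot_loss and datatang pieces dropped when the extra dataset is off. A small sketch of the string assembly, using dummy values in place of the MetricsTracker objects and batch statistics:

    # Dummy values standing in for the trackers and batch stats.
    loss_info, tot_loss = "0.35", "0.41"
    aishell_tot_loss, datatang_tot_loss = "0.40", "0.52"
    batch_size, prefix = 24, "aishell"

    for use_datatang in (True, False):
        if use_datatang:
            datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
            tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
        else:
            tot_loss_str = ""
            datatang_str = ""

        msg = (
            f"batch 100, {prefix}_loss[{loss_info}], "
            f"{tot_loss_str}"
            f"aishell_tot_loss[{aishell_tot_loss}], "
            f"{datatang_str}"
            f"batch size: {batch_size}"
        )
        print(msg)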
@@ -891,15 +916,18 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    # If it is None, tot_loss is the same as aishell_tot_loss.
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
-                datatang_tot_loss.write_summary(
-                    tb_writer, "train/datatang_tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    datatang_tot_loss.write_summary(
+                        tb_writer, "train/datatang_tot_", params.batch_idx_train
+                    )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1032,11 +1060,6 @@ def run(rank, world_size, args):
     train_cuts = aishell.train_cuts()
     train_cuts = filter_short_and_long_utterances(train_cuts)

-    datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
-    train_datatang_cuts = datatang.train_cuts()
-    train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
-    train_datatang_cuts = train_datatang_cuts.repeat(times=None)
-
     if args.enable_musan:
         cuts_musan = load_manifest(
             Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
@@ -1052,11 +1075,21 @@ def run(rank, world_size, args):
         cuts_musan=cuts_musan,
     )

-    datatang_train_dl = asr_datamodule.train_dataloaders(
-        train_datatang_cuts,
-        on_the_fly_feats=False,
-        cuts_musan=cuts_musan,
-    )
+    if params.datatang_prob > 0:
+        datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+        train_datatang_cuts = datatang.train_cuts()
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
+        train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+        datatang_train_dl = asr_datamodule.train_dataloaders(
+            train_datatang_cuts,
+            on_the_fly_feats=False,
+            cuts_musan=cuts_musan,
+        )
+    else:
+        datatang_train_dl = None
+        logging.info("Not using aidatatang_200zh for training")

     valid_cuts = aishell.valid_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
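A note on the relocated cut preparation: in lhotse, CutSet.repeat(times=None) yields the cuts in an endless lazy loop, so the datatang cuts can be consumed at whatever rate the rng.choices mixing in train_one_epoch dictates without ever being exhausted. Below is a rough sketch of the same gate with a generic extra dataset; the make_extra_dataloader helper is purely illustrative and itertools.cycle stands in for CutSet.repeat:

    import itertools
    import logging

    def make_extra_dataloader(datatang_prob: float, cuts):
        """Illustrative helper (not part of the commit): returns an endlessly
        repeating iterator over `cuts` when the mixing probability is non-zero,
        otherwise None, mirroring the gate added in run()."""
        if datatang_prob > 0:
            # itertools.cycle plays the role of CutSet.repeat(times=None):
            # the extra cut list never runs out.
            return itertools.cycle(cuts)
        logging.info("Not using the extra dataset for training")
        return None

    dl = make_extra_dataloader(0.2, ["cut-a", "cut-b", "cut-c"])
    print([next(dl) for _ in range(5)])   # ['cut-a', 'cut-b', 'cut-c', 'cut-a', 'cut-b']
    print(make_extra_dataloader(0.0, []))  # None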
@@ -1065,13 +1098,14 @@ def run(rank, world_size, args):
             train_dl,
             # datatang_train_dl
         ]:
-            scan_pessimistic_batches_for_oom(
-                model=model,
-                train_dl=dl,
-                optimizer=optimizer,
-                graph_compiler=graph_compiler,
-                params=params,
-            )
+            if dl is not None:
+                scan_pessimistic_batches_for_oom(
+                    model=model,
+                    train_dl=dl,
+                    optimizer=optimizer,
+                    graph_compiler=graph_compiler,
+                    params=params,
+                )

     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1083,7 +1117,8 @@ def run(rank, world_size, args):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
-        datatang_train_dl.sampler.set_epoch(epoch)
+        if datatang_train_dl is not None:
+            datatang_train_dl.sampler.set_epoch(epoch)

         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
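Finally, the set_epoch guard: samplers reseed their shuffling from the epoch number, so the call is only valid when the optional dataloader actually exists. A minimal sketch using torch's DistributedSampler as a stand-in for the lhotse sampler used in the recipe:

    from torch.utils.data import DistributedSampler

    dataset = list(range(8))

    # num_replicas/rank given explicitly so no process group is needed for the demo.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    for epoch in range(1, 3):
        # Same pattern as in run(): reseed the shuffling for this epoch.
        sampler.set_epoch(epoch)
        print(epoch, list(sampler))

    # An optional sampler has to be guarded exactly like datatang_train_dl above:
    optional_sampler = None
    if optional_sampler is not None:
        optional_sampler.set_epoch(1)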