Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Support using aidatatang_200zh optionally in aishell training (#495)
* Use aidatatang_200zh optionally in aishell training.
parent 4612b03947
commit d3fc4b031e
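At a glance, this commit makes the aidatatang_200zh half of the aishell recipe optional. Batches keep coming from aishell, and with probability --datatang-prob (default 0.2) a batch is drawn from aidatatang_200zh instead; setting the flag to 0 skips that dataset entirely. A tiny, illustrative check (not code from the diff) of what the default probability means in practice:

import random

# With --datatang-prob 0.2, roughly one training batch in five comes from aidatatang_200zh.
rng = random.Random(42)
datatang_prob = 0.2
picks = rng.choices((0, 1), weights=[1 - datatang_prob, datatang_prob], k=100_000)
print(sum(picks) / len(picks))  # ~0.2  (1 = aidatatang_200zh batch, 0 = aishell batch)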
@@ -62,6 +62,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,8 +345,11 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="""The probability to select a batch from the
+        aidatatang_200zh dataset.
+        If it is set to 0, you don't need to download the data
+        for aidatatang_200zh.
+        """,
     )
 
     add_model_arguments(parser)
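As a standalone illustration of the reworded help text (the minimal parser below is mine, not the recipe's full get_parser()), passing --datatang-prob 0 is the documented way to train on aishell alone without downloading aidatatang_200zh:

import argparse

# Minimal stand-in parser exposing only the option touched by this hunk.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--datatang-prob",
    type=float,
    default=0.2,
    help="Probability of selecting a batch from aidatatang_200zh; "
    "0 disables the dataset entirely.",
)

args = parser.parse_args(["--datatang-prob", "0"])
assert args.datatang_prob == 0.0  # with 0, the datatang dataloader is never built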
@@ -457,8 +461,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    decoder_datatang = get_decoder_model(params)
-    joiner_datatang = get_joiner_model(params)
+    if params.datatang_prob > 0:
+        decoder_datatang = get_decoder_model(params)
+        joiner_datatang = get_joiner_model(params)
+    else:
+        decoder_datatang = None
+        joiner_datatang = None
 
     model = Transducer(
         encoder=encoder,
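The pattern used here, sketched with a toy class rather than the recipe's actual Transducer: when aidatatang_200zh is disabled, the per-dataset decoder and joiner are simply stored as None, so no extra parameters are created and later code can test `is not None` before touching them.

from typing import Optional

import torch.nn as nn


class ToyTransducer(nn.Module):
    """Illustrative only; not the Transducer class used by this recipe."""

    def __init__(
        self,
        decoder: nn.Module,
        joiner: nn.Module,
        decoder_datatang: Optional[nn.Module] = None,
        joiner_datatang: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.joiner = joiner
        # None means aidatatang_200zh is not used; no parameters are allocated for it.
        self.decoder_datatang = decoder_datatang
        self.joiner_datatang = joiner_datatang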
@@ -726,7 +734,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
-    datatang_train_dl: torch.utils.data.DataLoader,
+    datatang_train_dl: Optional[torch.utils.data.DataLoader],
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: GradScaler,
@@ -778,13 +786,17 @@ def train_one_epoch(
     dl_weights = [1 - params.datatang_prob, params.datatang_prob]
 
     iter_aishell = iter(train_dl)
-    iter_datatang = iter(datatang_train_dl)
+    if datatang_train_dl is not None:
+        iter_datatang = iter(datatang_train_dl)
 
     batch_idx = 0
 
     while True:
-        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
-        dl = iter_aishell if idx == 0 else iter_datatang
+        if datatang_train_dl is not None:
+            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+            dl = iter_aishell if idx == 0 else iter_datatang
+        else:
+            dl = iter_aishell
 
         try:
             batch = next(dl)
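This loop is the heart of the change: each step draws an aishell batch or, with probability datatang_prob, an aidatatang_200zh batch, and when the extra dataloader is absent every batch comes from aishell. A self-contained sketch of the same idea with plain lists standing in for dataloaders (names are illustrative, not the recipe's):

import random
from typing import Iterable, Iterator, Optional


def mixed_batches(
    aishell_batches: Iterable[str],
    datatang_batches: Optional[Iterable[str]] = None,
    datatang_prob: float = 0.2,
    seed: int = 0,
) -> Iterator[str]:
    rng = random.Random(seed)
    dl_weights = [1 - datatang_prob, datatang_prob]
    iter_aishell = iter(aishell_batches)
    iter_datatang = iter(datatang_batches) if datatang_batches is not None else None
    while True:
        if iter_datatang is not None:
            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
            dl = iter_aishell if idx == 0 else iter_datatang
        else:
            dl = iter_aishell
        try:
            yield next(dl)
        except StopIteration:
            # In the recipe the datatang cuts are repeated indefinitely, so in
            # practice only exhausting the aishell loader ends the epoch.
            return


print(list(mixed_batches(["a1", "a2"], None)))                # ['a1', 'a2'] - aishell only
print(list(mixed_batches(["a1", "a2"], ["d1", "d2", "d3"])))  # a mixture of both sources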
@@ -808,7 +820,11 @@ def train_one_epoch(
                 warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        if datatang_train_dl is not None:
+            tot_loss = (
+                tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+
         if aishell:
             aishell_tot_loss = (
                 aishell_tot_loss * (1 - 1 / params.reset_interval)
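A side note on the statistic updated here (my reading of the formula, not text from the commit): multiplying by (1 - 1 / reset_interval) before adding the new batch keeps an exponentially decaying running total, so the logged numbers mostly reflect roughly the last reset_interval batches.

# Toy illustration with a constant per-batch loss of 2.0.
reset_interval = 200
running = 0.0
for batch_loss in [2.0] * 1000:
    running = running * (1 - 1 / reset_interval) + batch_loss
print(round(running, 1))  # ~397, close to the steady state reset_interval * batch_loss = 400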
@@ -871,12 +887,21 @@ def train_one_epoch(
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
+            if datatang_train_dl is not None:
+                datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
+            else:
+                tot_loss_str = ""
+                datatang_str = ""
 
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"{tot_loss_str}"
                 f"aishell_tot_loss[{aishell_tot_loss}], "
-                f"datatang_tot_loss[{datatang_tot_loss}], "
+                f"{datatang_str}"
+                f"batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}"
             )
@@ -891,15 +916,18 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    # If it is None, tot_loss is the same as aishell_tot_loss.
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
-                datatang_tot_loss.write_summary(
-                    tb_writer, "train/datatang_tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    datatang_tot_loss.write_summary(
+                        tb_writer, "train/datatang_tot_", params.batch_idx_train
+                    )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1032,11 +1060,6 @@ def run(rank, world_size, args):
     train_cuts = aishell.train_cuts()
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
-    datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
-    train_datatang_cuts = datatang.train_cuts()
-    train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
-    train_datatang_cuts = train_datatang_cuts.repeat(times=None)
-
     if args.enable_musan:
         cuts_musan = load_manifest(
             Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
@@ -1052,11 +1075,21 @@ def run(rank, world_size, args):
         cuts_musan=cuts_musan,
     )
 
-    datatang_train_dl = asr_datamodule.train_dataloaders(
-        train_datatang_cuts,
-        on_the_fly_feats=False,
-        cuts_musan=cuts_musan,
-    )
+    if params.datatang_prob > 0:
+        datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+        train_datatang_cuts = datatang.train_cuts()
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
+        train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+        datatang_train_dl = asr_datamodule.train_dataloaders(
+            train_datatang_cuts,
+            on_the_fly_feats=False,
+            cuts_musan=cuts_musan,
+        )
+    else:
+        datatang_train_dl = None
+        logging.info("Not using aidatatang_200zh for training")
 
     valid_cuts = aishell.valid_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
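Two details in this hunk are worth spelling out: the extra cuts are only loaded when datatang_prob > 0, and train_datatang_cuts.repeat(times=None) turns the aidatatang_200zh stream into an effectively endless one, so the aishell dataloader alone decides when an epoch ends. A rough analogue with plain Python iterators (not lhotse or the recipe's dataloaders):

import itertools
import random

aishell = ["a1", "a2", "a3", "a4"]
datatang = itertools.cycle(["d1", "d2"])  # analogue of train_datatang_cuts.repeat(times=None)

rng = random.Random(0)
iter_aishell = iter(aishell)
epoch = []
while True:
    use_datatang = rng.random() < 0.2
    try:
        epoch.append(next(datatang) if use_datatang else next(iter_aishell))
    except StopIteration:
        break  # aishell exhausted: the epoch is over
print(epoch)  # four aishell batches plus however many datatang batches were drawn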
@@ -1065,13 +1098,14 @@ def run(rank, world_size, args):
         train_dl,
         # datatang_train_dl
     ]:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=dl,
-            optimizer=optimizer,
-            graph_compiler=graph_compiler,
-            params=params,
-        )
+        if dl is not None:
+            scan_pessimistic_batches_for_oom(
+                model=model,
+                train_dl=dl,
+                optimizer=optimizer,
+                graph_compiler=graph_compiler,
+                params=params,
+            )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1083,7 +1117,8 @@ def run(rank, world_size, args):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
-        datatang_train_dl.sampler.set_epoch(epoch)
+        if datatang_train_dl is not None:
+            datatang_train_dl.sampler.set_epoch(epoch)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
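Finally, the new guard only matters because datatang_train_dl can now be None; set_epoch itself is the usual way to make a shuffling sampler produce a different order every epoch. A toy sampler showing the effect (illustrative, not the lhotse sampler API):

import random
from typing import Iterator, List


class ToyShufflingSampler:
    """Shuffles indices with an epoch-dependent seed, as set_epoch does on real samplers."""

    def __init__(self, num_items: int, base_seed: int = 0):
        self.num_items = num_items
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch  # changes the shuffle order used by the next iteration

    def __iter__(self) -> Iterator[int]:
        order: List[int] = list(range(self.num_items))
        random.Random(self.base_seed + self.epoch).shuffle(order)
        return iter(order)


sampler = ToyShufflingSampler(5)
sampler.set_epoch(0)
print(list(sampler))
sampler.set_epoch(1)
print(list(sampler))  # a different order for the next epoch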