Mirror of https://github.com/k2-fsa/icefall.git

commit b34eafa500 (parent 4612b03947)

    Use aidatatang_200zh optionally in aishell training.

@@ -62,6 +62,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule

@@ -344,8 +345,11 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="""The probability to select a batch from the
+        aidatatang_200zh dataset.
+        If it is set to 0, you don't need to download the data
+        for aidatatang_200zh.
+        """,
     )
 
     add_model_arguments(parser)

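For readers who want to try the new flag in isolation, here is a minimal, self-contained sketch of declaring a --datatang-prob style option with argparse and using it as an on/off switch. The parser below is a toy stand-in, not the recipe's full get_parser().

    import argparse

    # Toy parser (not icefall's get_parser): declare the flag and use it
    # to decide whether the auxiliary dataset is needed at all.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--datatang-prob",
        type=float,
        default=0.2,
        help="""The probability to select a batch from the
        aidatatang_200zh dataset.
        If it is set to 0, you don't need to download the data
        for aidatatang_200zh.
        """,
    )

    args = parser.parse_args(["--datatang-prob", "0"])
    use_datatang = args.datatang_prob > 0
    print(f"datatang_prob={args.datatang_prob}, use aidatatang_200zh: {use_datatang}")

Setting the probability to 0 means the auxiliary dataset is never sampled, which is what lets the recipe skip downloading aidatatang_200zh altogether.
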
@@ -778,13 +782,17 @@ def train_one_epoch(
     dl_weights = [1 - params.datatang_prob, params.datatang_prob]
 
     iter_aishell = iter(train_dl)
-    iter_datatang = iter(datatang_train_dl)
+    if datatang_train_dl is not None:
+        iter_datatang = iter(datatang_train_dl)
 
     batch_idx = 0
 
     while True:
-        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
-        dl = iter_aishell if idx == 0 else iter_datatang
+        if datatang_train_dl is not None:
+            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+            dl = iter_aishell if idx == 0 else iter_datatang
+        else:
+            dl = iter_aishell
 
         try:
             batch = next(dl)

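The core of the change is this batch-mixing loop. As a hedged, standalone illustration, the sketch below mimics it with plain Python lists in place of the lhotse dataloaders: draw from two sources with random.Random.choices, and fall back to the primary source alone when the auxiliary one is None. The mix_batches helper and the list-based "batches" are illustrative names, not part of icefall.

    import random

    def mix_batches(aishell_batches, datatang_batches, datatang_prob, seed=42):
        # Weighted choice between the two sources, mirroring dl_weights above.
        rng = random.Random(seed)
        dl_weights = [1 - datatang_prob, datatang_prob]

        iter_aishell = iter(aishell_batches)
        iter_datatang = (
            iter(datatang_batches) if datatang_batches is not None else None
        )

        while True:
            if iter_datatang is not None:
                idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
                dl = iter_aishell if idx == 0 else iter_datatang
            else:
                dl = iter_aishell
            try:
                yield next(dl)
            except StopIteration:
                # The loop ends when a source is exhausted; in the real recipe
                # the datatang cuts are repeated, so only aishell ends the epoch.
                break

    aishell_batches = [f"aishell-{i}" for i in range(5)]
    datatang_batches = [f"datatang-{i}" for i in range(100)]
    print(list(mix_batches(aishell_batches, datatang_batches, datatang_prob=0.2)))
    print(list(mix_batches(aishell_batches, None, datatang_prob=0.0)))
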
@@ -1032,11 +1040,6 @@ def run(rank, world_size, args):
     train_cuts = aishell.train_cuts()
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
-    datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
-    train_datatang_cuts = datatang.train_cuts()
-    train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
-    train_datatang_cuts = train_datatang_cuts.repeat(times=None)
-
     if args.enable_musan:
         cuts_musan = load_manifest(
             Path(args.manifest_dir) / "musan_cuts.jsonl.gz"

@@ -1052,11 +1055,21 @@ def run(rank, world_size, args):
         cuts_musan=cuts_musan,
     )
 
-    datatang_train_dl = asr_datamodule.train_dataloaders(
-        train_datatang_cuts,
-        on_the_fly_feats=False,
-        cuts_musan=cuts_musan,
-    )
+    if params.datatang_prob > 0:
+        datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+        train_datatang_cuts = datatang.train_cuts()
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
+        train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+        datatang_train_dl = asr_datamodule.train_dataloaders(
+            train_datatang_cuts,
+            on_the_fly_feats=False,
+            cuts_musan=cuts_musan,
+        )
+    else:
+        datatang_train_dl = None
+        logging.info("Not using aidatatang_200zh for training")
 
     valid_cuts = aishell.valid_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)

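One detail worth calling out is train_datatang_cuts.repeat(times=None): because batches are drawn probabilistically, the auxiliary cuts must never run out before the aishell epoch ends, so they are repeated indefinitely. A rough analogy using itertools.cycle, assuming lhotse's lazy CutSet.repeat(times=None) yields the cuts endlessly (the list below is only a stand-in for the CutSet):

    import itertools

    datatang_cuts = ["dt-0", "dt-1", "dt-2"]           # stand-in for the CutSet
    endless_datatang = itertools.cycle(datatang_cuts)  # like .repeat(times=None)

    # Draw more items than the underlying list holds; the stream simply wraps.
    print([next(endless_datatang) for _ in range(7)])
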
@@ -1065,13 +1078,14 @@ def run(rank, world_size, args):
             train_dl,
             # datatang_train_dl
         ]:
-            scan_pessimistic_batches_for_oom(
-                model=model,
-                train_dl=dl,
-                optimizer=optimizer,
-                graph_compiler=graph_compiler,
-                params=params,
-            )
+            if dl is not None:
+                scan_pessimistic_batches_for_oom(
+                    model=model,
+                    train_dl=dl,
+                    optimizer=optimizer,
+                    graph_compiler=graph_compiler,
+                    params=params,
+                )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:

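The new guard simply skips the pre-flight scan when the optional loader is absent. For readers unfamiliar with the helper, here is a generic, hedged sketch of what such a scan amounts to: run a few worst-case batches through a forward/backward pass and fail fast if they do not fit in memory. The toy model and scan_batches_for_oom function below are illustrative only, not icefall's scan_pessimistic_batches_for_oom.

    import torch
    import torch.nn as nn

    def scan_batches_for_oom(model, optimizer, batches):
        # Try the (presumably largest) batches up front so an OOM surfaces
        # before hours of training, with a pointer to the offending batch.
        for i, (x, y) in enumerate(batches):
            try:
                optimizer.zero_grad()
                loss = nn.functional.mse_loss(model(x), y)
                loss.backward()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    raise RuntimeError(f"Batch {i} is too large for memory") from e
                raise

    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(2)]
    scan_batches_for_oom(model, optimizer, batches)
    print("pessimistic batches fit in memory")
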
@@ -1083,7 +1097,8 @@ def run(rank, world_size, args):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
-        datatang_train_dl.sampler.set_epoch(epoch)
+        if datatang_train_dl is not None:
+            datatang_train_dl.sampler.set_epoch(epoch)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

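Guarding sampler.set_epoch(epoch) matters because that call is what reseeds shuffling for each epoch, and it can only be made when the loader actually exists. A small illustration of the same idea on torch's DistributedSampler (the recipe itself uses lhotse samplers, which expose a comparable set_epoch; num_replicas=1 here just avoids needing a real process group):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

    # Calling set_epoch before each epoch changes the shuffle order.
    for epoch in range(2):
        sampler.set_epoch(epoch)
        print(f"epoch {epoch}:", list(iter(sampler)))
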
|
Loading…
x
Reference in New Issue
Block a user