diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index 02efe94fe..0e5291b21 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -62,6 +62,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,8 +345,11 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="""The probability to select a batch from the
+        aidatatang_200zh dataset.
+        If it is set to 0, you don't need to download the data
+        for aidatatang_200zh.
+        """,
     )
 
     add_model_arguments(parser)
@@ -457,8 +461,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    decoder_datatang = get_decoder_model(params)
-    joiner_datatang = get_joiner_model(params)
+    if params.datatang_prob > 0:
+        decoder_datatang = get_decoder_model(params)
+        joiner_datatang = get_joiner_model(params)
+    else:
+        decoder_datatang = None
+        joiner_datatang = None
 
     model = Transducer(
         encoder=encoder,
@@ -726,7 +734,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
-    datatang_train_dl: torch.utils.data.DataLoader,
+    datatang_train_dl: Optional[torch.utils.data.DataLoader],
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: GradScaler,
@@ -778,13 +786,17 @@ def train_one_epoch(
     dl_weights = [1 - params.datatang_prob, params.datatang_prob]
 
     iter_aishell = iter(train_dl)
-    iter_datatang = iter(datatang_train_dl)
+    if datatang_train_dl is not None:
+        iter_datatang = iter(datatang_train_dl)
 
     batch_idx = 0
 
     while True:
-        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
-        dl = iter_aishell if idx == 0 else iter_datatang
+        if datatang_train_dl is not None:
+            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+            dl = iter_aishell if idx == 0 else iter_datatang
+        else:
+            dl = iter_aishell
 
         try:
             batch = next(dl)
@@ -808,7 +820,11 @@ def train_one_epoch(
                     warmup=(params.batch_idx_train / params.model_warm_step),
                 )
             # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+            if datatang_train_dl is not None:
+                tot_loss = (
+                    tot_loss * (1 - 1 / params.reset_interval)
+                ) + loss_info
+
             if aishell:
                 aishell_tot_loss = (
                     aishell_tot_loss * (1 - 1 / params.reset_interval)
@@ -871,12 +887,21 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
 
+            if datatang_train_dl is not None:
+                datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
+            else:
+                tot_loss_str = ""
+                datatang_str = ""
+
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"{tot_loss_str}"
                 f"aishell_tot_loss[{aishell_tot_loss}], "
-                f"datatang_tot_loss[{datatang_tot_loss}], "
+                f"{datatang_str}"
                 f"batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}"
             )
@@ -891,15 +916,18 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    # If it is None, tot_loss is the same as aishell_tot_loss.
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
-                datatang_tot_loss.write_summary(
-                    tb_writer, "train/datatang_tot_", params.batch_idx_train
-                )
+                if datatang_train_dl is not None:
+                    datatang_tot_loss.write_summary(
+                        tb_writer, "train/datatang_tot_", params.batch_idx_train
+                    )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1032,11 +1060,6 @@ def run(rank, world_size, args):
     train_cuts = aishell.train_cuts()
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
-    datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
-    train_datatang_cuts = datatang.train_cuts()
-    train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
-    train_datatang_cuts = train_datatang_cuts.repeat(times=None)
-
     if args.enable_musan:
         cuts_musan = load_manifest(
             Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
@@ -1052,11 +1075,21 @@ def run(rank, world_size, args):
         cuts_musan=cuts_musan,
     )
 
-    datatang_train_dl = asr_datamodule.train_dataloaders(
-        train_datatang_cuts,
-        on_the_fly_feats=False,
-        cuts_musan=cuts_musan,
-    )
+    if params.datatang_prob > 0:
+        datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+        train_datatang_cuts = datatang.train_cuts()
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
+        train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+        datatang_train_dl = asr_datamodule.train_dataloaders(
+            train_datatang_cuts,
+            on_the_fly_feats=False,
+            cuts_musan=cuts_musan,
+        )
+    else:
+        datatang_train_dl = None
+        logging.info("Not using aidatatang_200zh for training")
 
     valid_cuts = aishell.valid_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
@@ -1065,13 +1098,14 @@ def run(rank, world_size, args):
         train_dl,
         # datatang_train_dl
     ]:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=dl,
-            optimizer=optimizer,
-            graph_compiler=graph_compiler,
-            params=params,
-        )
+        if dl is not None:
+            scan_pessimistic_batches_for_oom(
+                model=model,
+                train_dl=dl,
+                optimizer=optimizer,
+                graph_compiler=graph_compiler,
+                params=params,
+            )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1083,7 +1117,8 @@ def run(rank, world_size, args):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
-        datatang_train_dl.sampler.set_epoch(epoch)
+        if datatang_train_dl is not None:
+            datatang_train_dl.sampler.set_epoch(epoch)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
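For context, here is a minimal standalone sketch (not part of the patch) of the dataloader-mixing pattern that the `datatang_train_dl is not None` guards above protect. The function name `mix_batches`, the toy lists, and the seed are illustrative only; the real recipe draws from Lhotse DataLoaders inside `train_one_epoch` and repeats the aidatatang_200zh cuts indefinitely via `train_datatang_cuts.repeat(times=None)`.

```python
import random
from typing import Iterable, Iterator, Optional


def mix_batches(
    aishell_dl: Iterable,
    datatang_dl: Optional[Iterable],
    datatang_prob: float,
    rng: random.Random,
) -> Iterator:
    """Yield batches, drawing from datatang_dl with probability
    datatang_prob when it is available; otherwise fall back to
    aishell_dl only (mirrors the guarded logic in train_one_epoch)."""
    dl_weights = [1 - datatang_prob, datatang_prob]
    iter_aishell = iter(aishell_dl)
    # Only build the second iterator when the dataloader exists,
    # i.e. when --datatang-prob > 0 in the patched recipe.
    iter_datatang = iter(datatang_dl) if datatang_dl is not None else None

    while True:
        if iter_datatang is not None:
            idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
            dl = iter_aishell if idx == 0 else iter_datatang
        else:
            dl = iter_aishell
        try:
            yield next(dl)
        except StopIteration:
            # In this sketch the epoch ends when either iterator is
            # exhausted; in the recipe the datatang cuts are repeated,
            # so the AIShell loader defines the epoch boundary.
            return


# Toy usage with plain lists standing in for DataLoaders:
if __name__ == "__main__":
    rng = random.Random(42)
    aishell = [f"aishell-{i}" for i in range(5)]
    datatang = [f"datatang-{i}" for i in range(100)]
    print(list(mix_batches(aishell, datatang, datatang_prob=0.2, rng=rng)))
    print(list(mix_batches(aishell, None, datatang_prob=0.0, rng=rng)))
```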