From 2e3ff0b31fc57451036f6325b5c3e12d8f164171 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 4 Jan 2023 10:51:36 +0800
Subject: [PATCH] update

---
 .../train.py | 101 ++++++++++--------
 1 file changed, 55 insertions(+), 46 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 537d5deca..564551fe0 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -166,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )
 
+    # Distillation-related arguments
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks.",
+    )
+
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -358,41 +393,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--enable-distillation",
-        type=str2bool,
-        default=True,
-        help="Whether to eanble distillation.",
-    )
-
-    parser.add_argument(
-        "--distillation-layer",
-        type=int,
-        default=8,
-        help="On which encoder layer to perform KD"
-    )
-
-    parser.add_argument(
-        "--num-codebooks",
-        type=int,
-        default=16,
-        help="Number of codebooks"
-    )
-
-    parser.add_argument(
-        "--distil-delta",
-        type=int,
-        default=None,
-        help="Offset when doing KD"
-    )
-
-    parser.add_argument(
-        "--codebook-loss-scale",
-        type=float,
-        default=0.1,
-        help="The scale of codebook loss.",
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -444,6 +444,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -652,6 +653,9 @@ def extract_codebook_indexes(batch):
     cuts_pre_mixed = [
         c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
     ]
+    # Sanity check: every cut must carry pre-computed codebook indexes.
+    for cut in cuts_pre_mixed:
+        assert hasattr(cut, "codebook_indexes"), f"Cut {cut.id} has no codebook_indexes"
     codebook_indexes, codebook_indexes_lens = collate_custom_field(
         cuts_pre_mixed, "codebook_indexes", pad_value=-100
     )
@@ -969,6 +973,11 @@ def run(rank, world_size, args):
         setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
 
+    # Note: set --spec-aug-time-warp-factor=-1 to disable time warping
+    # when doing distillation with VQ codebook indexes.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "You need to disable time warp in MVQ KD"
+
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
@@ -1034,10 +1043,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
+    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1067,14 +1076,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
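
The following is a minimal sketch, not part of the patch above, of how the new
--enable-distillation and --codebook-loss-scale flags are typically consumed
when a codebook (MVQ knowledge-distillation) loss is combined with the
transducer losses. The names combine_losses, simple_loss, pruned_loss,
codebook_loss, and simple_loss_scale are placeholders chosen for this
illustration and are not taken from this diff.

import torch


def combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    codebook_loss: torch.Tensor,
    simple_loss_scale: float = 0.5,
    codebook_loss_scale: float = 0.1,  # default of the new --codebook-loss-scale flag
    enable_distillation: bool = True,  # default of the new --enable-distillation flag
) -> torch.Tensor:
    """Combine transducer losses with an optional weighted codebook (KD) loss."""
    loss = simple_loss_scale * simple_loss + pruned_loss
    if enable_distillation:
        loss = loss + codebook_loss_scale * codebook_loss
    return loss


# Example usage with dummy scalar losses:
# combine_losses(torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0))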