diff --git a/egs/librispeech/ASR/.bitfit.sh.swp b/egs/librispeech/ASR/.bitfit.sh.swp
index db2719e5c..5861c8cef 100644
Binary files a/egs/librispeech/ASR/.bitfit.sh.swp and b/egs/librispeech/ASR/.bitfit.sh.swp differ
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.decode.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.decode.py.swp
new file mode 100644
index 000000000..fe30ed021
Binary files /dev/null and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.decode.py.swp differ
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_bitfit_tta.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_bitfit_tta.py.swp
index 10a5113d2..b31384a95 100644
Binary files a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_bitfit_tta.py.swp and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_bitfit_tta.py.swp differ
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py
index 849a061c1..ccd94c641 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py
@@ -1274,6 +1274,252 @@ def run(rank, world_size, args, wb=None):
     """
     params = get_params()
     params.update(vars(args))
+    #params.warm_step *= params.accum_grads
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+    logging.info(model)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    if params.multi_optim:
+        logging.info("Using separate optimizers over encoder, decoder ...")
+
+        enc_param = []
+        enc_names = []
+
+        dec_names = []
+        dec_param = []
+
+        for n, p in model.named_parameters():
+            name = n.split('.')[1]
+            if name == 'encoder' and 'feature_extractor' not in n:
+                enc_names.append(n)
+                enc_param.append(p)
+            elif 'ctc_output' in n:
+                enc_names.append(n)
+                enc_param.append(p)
+            elif 'feature_extractor' not in n:
+                dec_names.append(n)
+                dec_param.append(p)
+
+        optimizer_enc = ScaledAdam(
+            enc_param,
+            lr=params.peak_enc_lr,
+            clipping_scale=None,
+            parameters_names=[enc_names],
+        )
+        optimizer_dec = ScaledAdam(
+            dec_param,
+            lr=params.peak_dec_lr,
+            clipping_scale=5.0,
+            parameters_names=[dec_names],
+        )
+
+        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
+        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
+        optimizer = [optimizer_enc, optimizer_dec]
+        scheduler = [scheduler_enc, scheduler_dec]
+
+    else:
+        parameters_names = []
+        parameters_names.append(
+            [name_param_pair[0] for name_param_pair in model.named_parameters()]
+        )
+
+        logging.info(f"len name = {len(parameters_names)}")
+        logging.info(f"len param = {len(list(model.parameters()))}")
+
+        optimizer = ScaledAdam(
+            model.parameters(),
+            lr=params.base_lr,
+            clipping_scale=2.0,
+            parameters_names=parameters_names,
+        )
+
+        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
+        if params.multi_optim:
+            logging.info("Loading optimizer state dict")
+            optimizer_enc.load_state_dict(checkpoints["optimizer_enc"])
+            optimizer_dec.load_state_dict(checkpoints["optimizer_dec"])
+
+        else:
+            logging.info("Loading optimizer state dict")
+            optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints:
+        if (
+            params.multi_optim
+            and "scheduler_enc" in checkpoints
+            and checkpoints["scheduler_enc"] is not None
+        ):
+            logging.info("Loading enc/dec scheduler state dict")
+            scheduler_enc.load_state_dict(checkpoints["scheduler_enc"])
+            scheduler_dec.load_state_dict(checkpoints["scheduler_dec"])
+        else:
+            logging.info("Loading scheduler state dict")
+            scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    '''
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+    '''
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        if params.multi_optim:
+            scheduler_enc.step_epoch(epoch - 1)
+            scheduler_dec.step_epoch(epoch - 1)
+        else:
+            scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+            wb=wb,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def run_adapter(rank, world_size, args, wb=None):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -1349,34 +1595,10 @@ def run(rank, world_size, args, wb=None):
         parameters_names=[adapter_names],
     )
 
-    #for n, p in model.named_parameters():
-    #    p.requires_grad = False
-
-    #prompt = torch.randn((100, 512), requires_grad=True)
-    #optimizer_adapter = ScaledAdam(
-    #    [model.prompt],
-    #    lr=params.adapter_lr,
-    #    clipping_scale=5.0,
-    #    parameters_names=['P'],
-    #)
-
     scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs)
     optimizer, scheduler = optimizer_adapter, scheduler_adapter
 
     librispeech = LibriSpeechAsrDataModule(args)
-
-    '''
-    if params.hpo:
-        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
-    else:
-        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
-        if params.full_libri:
-            train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
-            train_cuts += librispeech.train_other_500_cuts(option=params.gender)
-    '''
-
-    #train_cuts = librispeech.train_clean_10_cuts(option='male')
-    #train_cuts = librispeech.test_clean_user(option='big')
     train_cuts = librispeech.vox_cuts(option=params.spk_id)
 
     def remove_short_and_long_utt(c: Cut):
@@ -1389,19 +1611,6 @@ def run(rank, world_size, args, wb=None):
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    #train_dl = librispeech.test_dataloaders(
-    #    train_cuts
-    #)
-
-    '''
-    print('\n'*5)
-    print('-'*30)
-    for batch in train_dl:
-        print(batch)
-        print('-'*30)
-    print('\n'*5)
-    exit()
-    '''
 
     valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
     valid_cuts += librispeech.dev_other_cuts(option=params.gender)
@@ -1440,20 +1649,6 @@ def run(rank, world_size, args, wb=None):
             diagnostic.print_diagnostics()
             break
 
-        '''
-        if epoch % 10 == 0:
-            save_checkpoint(
-                params=params,
-                model=model,
-                model_avg=model_avg,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-        '''
-
     logging.info("Done!")
 
     if world_size > 1:
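
Note: the core of the first hunk is the `multi_optim` branch, which splits the model's parameters into an encoder group (everything under `encoder`, plus `ctc_output`, excluding `feature_extractor` parameters) and a decoder/joiner group, then drives each group with its own ScaledAdam optimizer and Eden scheduler. The sketch below illustrates only that partitioning pattern and is not code from the patch: stock AdamW/ExponentialLR stand in for icefall's ScaledAdam/Eden, the toy model is hypothetical, and the two learning rates merely mirror the roles of `peak_enc_lr` and `peak_dec_lr`.

import torch
import torch.nn as nn


class ToyTransducer(nn.Module):
    # Hypothetical stand-in for the data2vec transducer built by get_transducer_model().
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 256))
        self.decoder = nn.Embedding(500, 256)
        self.joiner = nn.Linear(256, 500)
        self.ctc_output = nn.Linear(256, 500)


model = ToyTransducer()

enc_params, enc_names = [], []
dec_params, dec_names = [], []
for name, p in model.named_parameters():
    # The patch inspects n.split('.')[1] because its model is DDP-wrapped
    # ("module.encoder. ..."); this bare toy model uses index 0 instead.
    top = name.split(".")[0]
    if top == "encoder" and "feature_extractor" not in name:
        enc_names.append(name)
        enc_params.append(p)
    elif "ctc_output" in name:
        enc_names.append(name)
        enc_params.append(p)
    elif "feature_extractor" not in name:
        dec_names.append(name)
        dec_params.append(p)

# One optimizer/scheduler pair per group; AdamW/ExponentialLR are stand-ins
# for ScaledAdam/Eden, with illustrative learning rates only.
optimizer_enc = torch.optim.AdamW(enc_params, lr=1e-4)
optimizer_dec = torch.optim.AdamW(dec_params, lr=5e-4)
scheduler_enc = torch.optim.lr_scheduler.ExponentialLR(optimizer_enc, gamma=0.98)
scheduler_dec = torch.optim.lr_scheduler.ExponentialLR(optimizer_dec, gamma=0.98)

# Downstream code treats the pair as lists, mirroring the patch's
# `optimizer = [optimizer_enc, optimizer_dec]` / `scheduler = [scheduler_enc, scheduler_dec]`.
optimizers = [optimizer_enc, optimizer_dec]
schedulers = [scheduler_enc, scheduler_dec]

# One dummy step to show that the two groups are updated independently.
loss = model.joiner(model.encoder(torch.randn(8, 80))).sum()
loss.backward()
for opt in optimizers:
    opt.step()
    opt.zero_grad()
for sch in schedulers:
    sch.step()

print(f"encoder group: {len(enc_names)} tensors, decoder/joiner group: {len(dec_names)} tensors")

In the patch itself the encoder group is built with clipping_scale=None while the decoder group keeps clipping_scale=5.0, and train_one_epoch and the per-epoch scheduler stepping receive the optimizers and schedulers as these two-element lists whenever params.multi_optim is set.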