From af58c4c409915e552315755d9defaf5c48d9891d Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Mon, 26 Dec 2022 13:40:35 +0900
Subject: [PATCH] from local

---
 .../.train.py.swp |  Bin 73728 -> 86016 bytes
 .../train.py      |  258 ++++++++++++++++++
 2 files changed, 258 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 95d986a107ecb7ae00a02b9637bacdd2165fd420..78395208d23cc43f33b9ad7d99afa118606ccd47 100644
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+    logging.info(model)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    if params.multi_optim:
+        logging.info("Using separate optimizers over encoder, decoder ...")
+
+        enc_param = []
+        enc_names = []
+
+        dec_names = []
+        dec_param = []
+
+        for n, p in model.named_parameters():
+            name = n.split('.')[1]
+            if name == 'encoder' and 'feature_extractor' not in n:
+                enc_names.append(n)
+                enc_param.append(p)
+            elif 'ctc_output' in n:
+                enc_names.append(n)
+                enc_param.append(p)
+            elif 'feature_extractor' not in n:
+                dec_names.append(n)
+                dec_param.append(p)
+
+        optimizer_enc = ScaledAdam(
+            enc_param,
+            lr=params.peak_enc_lr,
+            clipping_scale=None,
+            parameters_names=[enc_names],
+        )
+        optimizer_dec = ScaledAdam(
+            dec_param,
+            lr=params.peak_dec_lr,
+            clipping_scale=5.0,
+            parameters_names=[dec_names],
+        )
+
+        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
+        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
+        optimizer = [optimizer_enc, optimizer_dec]
+        scheduler = [scheduler_enc, scheduler_dec]
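+        # In multi_optim mode `optimizer` and `scheduler` are lists
+        # (encoder entry first, decoder entry second); the checkpoint
+        # loading and the per-epoch step_epoch() calls below branch on
+        # params.multi_optim to handle this.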
+
+    else:
+        parameters_names = []
+        parameters_names.append(
+            [name_param_pair[0] for name_param_pair in model.named_parameters()]
+        )
+
+        logging.info(f"len name = {len(parameters_names)}")
+        logging.info(f"len param = {len(list(model.parameters()))}")
+
+        optimizer = ScaledAdam(
+            model.parameters(),
+            lr=params.base_lr,
+            clipping_scale=2.0,
+            parameters_names=parameters_names,
+        )
+
+        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
+        if params.multi_optim:
+            logging.info("Loading optimizer state dict")
+            optimizer_enc.load_state_dict(checkpoints["optimizer_enc"])
+            optimizer_dec.load_state_dict(checkpoints["optimizer_dec"])
+
+        else:
+            logging.info("Loading optimizer state dict")
+            optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints:
+        if (
+            params.multi_optim
+            and "scheduler_enc" in checkpoints
+            and checkpoints["scheduler_enc"] is not None
+        ):
+            logging.info("Loading enc/dec scheduler state dict")
+            scheduler_enc.load_state_dict(checkpoints["scheduler_enc"])
+            scheduler_dec.load_state_dict(checkpoints["scheduler_dec"])
+        else:
+            logging.info("Loading scheduler state dict")
+            scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    '''
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+    '''
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        if params.multi_optim:
+            scheduler_enc.step_epoch(epoch - 1)
+            scheduler_dec.step_epoch(epoch - 1)
+        else:
+            scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+            wb=wb,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,