diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_uda.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_uda.py.swp
index ffaeb529d..d7dba7dd2 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_uda.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_uda.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_uda.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_uda.py
index 094af65c7..f8b5e89fe 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_uda.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_uda.py
@@ -1663,6 +1663,186 @@ def run_adapter(rank, world_size, args, wb=None):
         cleanup_dist()
 
 
+def run_adapter_uda(rank, world_size, args, wb=None):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoints.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    adapter_names = []
+    adapter_param = []
+    for n, p in model.named_parameters():
+        if 'adapters' in n:  # or 'joiner' in n or 'simple' in n or 'ctc' in n:
+            adapter_names.append(n)
+            adapter_param.append(p)
+        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
+            p.requires_grad = True
+        else:
+            p.requires_grad = False
+
+    optimizer_adapter = ScaledAdam(
+        adapter_param,
+        lr=params.adapter_lr,
+        clipping_scale=5.0,
+        parameters_names=[adapter_names],
+    )
+    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # params.lr_batches, params.lr_epochs
+
+    optimizer, scheduler = optimizer_adapter, scheduler_adapter
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    '''
+    if params.hpo:
+        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
+    else:
+        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
+            train_cuts += librispeech.train_other_500_cuts(option=params.gender)
+    '''
+
+    # train_cuts = librispeech.train_clean_10_cuts(option='male')
+    # train_cuts = librispeech.test_clean_user(option='big')
+    train_cuts = librispeech.vox_cuts(option=params.spk_id)
+
+    def remove_short_and_long_utt(c: Cut):
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+    # train_dl = librispeech.test_dataloaders(
+    #     train_cuts
+    # )
+
+    '''
+    print('\n'*5)
+    print('-'*30)
+    for batch in train_dl:
+        print(batch)
+    print('-'*30)
+    print('\n'*5)
+    exit()
+    '''
+
+    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
+    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+            wb=wb,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        if epoch % 10 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,