diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp
new file mode 100644
index 000000000..30bc235e8
Binary files /dev/null and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py
index f79b08d7a..173404859 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py
@@ -153,6 +153,13 @@ def add_adapter_arguments(parser: argparse.ArgumentParser):
         help="adapter learning rate"
     )
 
+    parser.add_argument(
+        "--gender",
+        type=str,
+        default='male',
+        help="Speaker gender used to select training/validation cuts"
+    )
+
 
 def add_rep_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -161,6 +168,13 @@ def add_rep_arguments(parser: argparse.ArgumentParser):
         default=True,
         help="Use wandb for MLOps",
     )
+    parser.add_argument(
+        "--hpo",
+        type=str2bool,
+        default=False,
+        help="Use a small training subset for hyper-parameter optimization (HPO)",
+    )
+
     parser.add_argument(
         "--accum-grads",
         type=int,
@@ -286,14 +300,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--decoder-dim",
         type=int,
-        default=512,
+        default=768,
         help="Embedding dimension in the decoder model.",
     )
 
     parser.add_argument(
         "--joiner-dim",
         type=int,
-        default=512,
+        default=768,
         help="""Dimension used in the joiner model.
         Outputs from the encoder and decoder model are projected
         to this dimension before adding.
@@ -333,6 +347,13 @@ def get_parser():
         default=30,
         help="Number of epochs to train.",
     )
+
+    parser.add_argument(
+        "--num-updates",
+        type=int,
+        default=5000,
+        help="Maximum number of parameter updates to train.",
+    )
 
     parser.add_argument(
         "--start-epoch",
@@ -461,7 +482,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=200,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -485,7 +506,7 @@ def get_parser():
     parser.add_argument(
         "--average-period",
         type=int,
-        default=200,
+        default=10,
         help="""Update the averaged model, namely `model_avg`, after
         processing this number of batches. `model_avg` is a separate
         version of model, in which each floating-point parameter is the average of all the
@@ -561,7 +582,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 20,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
@@ -570,7 +591,8 @@ def get_params() -> AttributeDict:
             # parameters for ctc loss
             "beam_size": 10,
             "use_double_scores": True,
-            "warm_step": 4000,
+            "warm_step": 0,
+            #"warm_step": 4000,
             #"warm_step": 3000,
             "env_info": get_env_info(),
         }
@@ -685,7 +707,7 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     elif params.add_adapter:
-        filename = params.exp_dir / f"d2v-base-T.pt"
+        filename = params.exp_dir / f"../d2v-base-T.pt"
     else:
         return None
@@ -717,6 +739,8 @@ def load_checkpoint_if_available(
     if "cur_batch_idx" in saved_params:
         params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
+    params.batch_idx_train = 0
+
     return saved_params
@@ -818,6 +842,7 @@ def compute_loss(
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
+
     token_ids = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(token_ids).to(device)
@@ -888,12 +913,21 @@ def compute_loss(
     if decode:
         model.eval()
         with torch.no_grad():
-            hypos = model.module.decode(
-                x=feature,
-                x_lens=feature_lens,
-                y=y,
-                sp=sp
-            )
+            try:
+                hypos = model.module.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
+            except AttributeError:  # model is not wrapped in DDP, so it has no .module
+                hypos = model.decode(
+                    x=feature,
+                    x_lens=feature_lens,
+                    y=y,
+                    sp=sp
+                )
+
         logging.info(f'ref: {batch["supervisions"]["text"][0]}')
         logging.info(f'hyp: {" ".join(hypos[0])}')
         model.train()
@@ -1002,6 +1036,8 @@ def train_one_epoch(
         scheduler_enc, scheduler_dec = scheduler[0], scheduler[1]
 
     for batch_idx, batch in enumerate(train_dl):
+        if params.batch_idx_train > params.num_updates:
+            break
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx
@@ -1019,7 +1055,9 @@ def train_one_epoch(
                 is_training=True,
                 decode = True if batch_idx % params.decode_interval == 0 else False,
             )
-        loss_info.reduce(loss.device)
+
+        try: loss_info.reduce(loss.device)
+        except Exception: pass  # ignore failures when reducing loss stats
 
         numel = params.world_size / (params.accum_grads * loss_info["utterances"])
         loss *= numel ## normalize loss over utts(batch size)
@@ -1053,7 +1091,8 @@ def train_one_epoch(
 
         if params.print_diagnostics and batch_idx == 5:
             return
-
+
+        '''
         if (
             rank == 0
             and params.batch_idx_train > 0
@@ -1064,6 +1103,7 @@ def train_one_epoch(
                 model_cur=model,
                 model_avg=model_avg,
             )
+        '''
 
         if (
             params.batch_idx_train > 0
@@ -1083,11 +1123,13 @@ def train_one_epoch(
                 rank=rank,
             )
             del params.cur_batch_idx
+            '''
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
                 rank=rank,
             )
+            '''
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -1106,13 +1148,16 @@ def train_one_epoch(
                     f"grad_scale is too small, exiting: {cur_grad_scale}"
                 )
 
-        if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
-            wb.log({"valid/loss": 10000})
-            raise RunteimError(
-                f"divergence... exiting: loss={loss}"
-            )
+        #if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
+        #    wb.log({"valid/loss": 10000})
+        #    raise RuntimeError(
+        #        f"divergence... exiting: loss={loss}"
+        #    )
 
         if batch_idx % (params.log_interval*params.accum_grads) == 0:
+            #for n, p in model.named_parameters():
+            #    if 'adapter' in n:
+            #        print(p)
             if params.multi_optim:
                 cur_enc_lr = scheduler_enc.get_last_lr()[0]
                 cur_dec_lr = scheduler_dec.get_last_lr()[0]
@@ -1169,7 +1214,8 @@ def train_one_epoch(
                 wb.log({"train/simple_loss": loss_info["simple_loss"]*numel})
                 wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
                 wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})
-
+
+            '''
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1190,11 +1236,15 @@ def train_one_epoch(
             if wb is not None and rank == 0:
                 numel = 1 / (params.accum_grads * valid_info["utterances"])
-                wb.log({"valid/loss": valid_info["loss"]*numel})
+                #wb.log({"valid/loss": valid_info["loss"]*numel})
+                wb.log({"valid/loss": numel*(valid_info["simple_loss"]
+                                             +valid_info["pruned_loss"]
+                                             +valid_info["ctc_loss"]
+                                             )})
                 wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
                 wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
                 wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
-
+
+            '''
     loss_value = tot_loss["loss"] / tot_loss["utterances"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -1247,7 +1297,6 @@ def run(rank, world_size, args, wb=None):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -1441,17 +1490,20 @@ def run(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
-
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+
+        '''
+        if epoch % 50 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
 
     logging.info("Done!")
 
@@ -1526,46 +1578,72 @@ def run_adapter(rank, world_size, args, wb=None):
     adapter_names = []
     adapter_param = []
     for n, p in model.named_parameters():
-        if 'adapters' in n:
+        if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
             adapter_names.append(n)
             adapter_param.append(p)
-        #else:
-        #    p.requires_grad = False
+        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
+            p.requires_grad = True
+        else:
+            p.requires_grad = False
+
     optimizer_adapter = ScaledAdam(
         adapter_param,
         lr=params.adapter_lr,
         clipping_scale=5.0,
         parameters_names=[adapter_names],
    )
-    scheduler_adapter = Eden(optimizer_adapter, 5000, 3.5) #params.lr_batche, params.lr_epochs)
+    scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs)
 
     optimizer, scheduler = optimizer_adapter, scheduler_adapter
 
     librispeech = LibriSpeechAsrDataModule(args)
+
+    '''
+    if params.hpo:
+        train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
+    else:
+        train_cuts = librispeech.train_clean_100_cuts(option=params.gender)
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts(option=params.gender)
+            train_cuts += librispeech.train_other_500_cuts(option=params.gender)
+    '''
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
-
+    #train_cuts = librispeech.train_clean_10_cuts(option='male')
+    #train_cuts = librispeech.test_clean_user(option='big')
+    train_cuts = librispeech.vox_cuts(option=params.spk_id)
+
     def remove_short_and_long_utt(c: Cut):
         return 1.0 <= c.duration <= 20.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
+
     sampler_state_dict = None
     train_dl = librispeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
+    #train_dl = librispeech.test_dataloaders(
+    #    train_cuts
+    #)
+
+    '''
+    print('\n'*5)
+    print('-'*30)
+    for batch in train_dl:
+        print(batch)
+        print('-'*30)
+        print('\n'*5)
+    exit()
+    '''
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
+    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"update num : {params.batch_idx_train}")
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
@@ -1594,17 +1672,20 @@ def run_adapter(rank, world_size, args, wb=None):
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
-
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+
+        '''
+        if epoch % 10 == 0:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+        '''
 
     logging.info("Done!")
 
@@ -1691,13 +1772,13 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    #args.exp_dir = args.exp_dir + str(random.randint(0,400))
+    if args.wandb: args.exp_dir = args.exp_dir + str(random.randint(0,400))
     args.exp_dir = Path(args.exp_dir)
 
     logging.info("save arguments to config.yaml...")
     save_args(args)
-
-    if args.wandb: wb = wandb.init(project="d2v-T", entity="dohe0342", config=vars(args))
+
+    if args.wandb: wb = wandb.init(project="d2v-adapter", entity="dohe0342", config=vars(args))
     else: wb = None
 
     world_size = args.world_size
@@ -1709,7 +1790,7 @@ def main():
             join=True
         )
     else:
-        if not args.add_adapter: run(rank=0, world_size=1, args=args, wb=wb)
+        if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
         else: run(rank=0, world_size=1, args=args, wb=wb)
 
 torch.set_num_threads(1)