diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
index 5521c56c8..9bff19f09 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
@@ -38,6 +38,26 @@ from convolution import ConvolutionModule
 
 logger = logging.getLogger().setLevel(logging.INFO)
 
+class LoRAHook:
+    def __init__(self, module, embedding_dim, rank, lora_alpha):
+        self.hook = module.register_forward_hook(self.hook_fn)
+        # Keep the adapter on the same device as the hooked layer.
+        self.lora = LoRAModule(
+            embedding_dim=embedding_dim,
+            rank=rank,
+            lora_alpha=lora_alpha,
+        ).to(next(module.parameters()).device)
+
+    def hook_fn(self, module, input, output):
+        # Returning a value from a forward hook replaces the layer's output.
+        return output + self.lora(input[0])
+
+    def save_checkpoint(self, i, iter_, save_dir):
+        # Unwrap DDP so the checkpoint holds the plain module's weights.
+        lora = self.lora.module if isinstance(self.lora, DDP) else self.lora
+        torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")
+
+
 class TransformerEncoderAdapter(TransformerEncoder):
     def __init__(self, args: Wav2Vec2Config):
         super().__init__(args)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
similarity index 91%
rename from egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py
rename to egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
index ccd94c641..feeac7699 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_bitfit_tta.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
@@ -101,6 +101,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
 from data2vec_encoder import FairSeqData2VecEncoder
+from data2vec_audio import LoRAModule, LoRAHook
 
 from icefall import diagnostics
 from icefall.checkpoint import remove_checkpoints
@@ -123,8 +124,8 @@ from icefall.utils import (
 )
 
 import wandb
+import fairseq
 
-#from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -138,26 +139,33 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     model.encoder.num_updates = int(batch_count)
 
 
-def add_adapter_arguments(parser: argparse.ArgumentParser):
+def add_pea_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--add-adapter",
+        "--adapter",
         type=str2bool,
         default=False,
help="add adapter to rep model's encoder" ) parser.add_argument( - "--adapter-lr", - type=float, - default=0.0001, - help="adapter learning rate" + "--bitfit", + type=str2bool, + default=False, + help="bias only training for PEA" + ) + + parser.add_argument( + "--lora", + type=str2bool, + default=False, + help="Low Rank Adaptation training for PEA" ) parser.add_argument( - "--gender", - type=str, - default='male', - help="select gender" + "--pea-lr", + type=float, + default=0.0001, + help="PEA learning rate" ) @@ -314,12 +322,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) - parser.add_argument( - "--prompt", - type=str2bool, - default=False, - ) - def get_parser(): parser = argparse.ArgumentParser( @@ -528,10 +530,16 @@ def get_parser(): default=True, help="Whether to use half precision training.", ) - + + parser.add_argument( + "--pea", + type=str2bool, + default=True, + help="Whether to train parameter efficient adaptation", + ) add_model_arguments(parser) add_rep_arguments(parser) - add_adapter_arguments(parser) + add_pea_arguments(parser) return parser @@ -588,7 +596,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 5, + "log_interval": 20, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 # parameters for zipformer @@ -663,7 +671,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) - + model = Transducer( encoder=encoder, decoder=decoder, @@ -672,8 +680,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, - prompt=params.prompt, - sid=params.spk_id, ) return model @@ -714,7 +720,7 @@ def load_checkpoint_if_available( filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - elif params.add_adapter: + elif params.pea: filename = params.exp_dir / f"../d2v-base-T.pt" else: return None @@ -727,7 +733,7 @@ def load_checkpoint_if_available( model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - strict=True if not params.add_adapter else False, + strict=True if not params.pea else False, ) keys = [ @@ -1001,6 +1007,7 @@ def train_one_epoch( world_size: int = 1, rank: int = 0, wb = None, + lora_modules = None, ) -> None: """Train the model for one epoch. 
@@ -1100,54 +1107,23 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return - ''' - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - ''' - if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) del params.cur_batch_idx - ''' - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - ''' + + if rank == 0: + for i, lora in enumerate(lora_modules): + lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir) if batch_idx % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - ''' - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - ''' + if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: @@ -1156,16 +1132,7 @@ def train_one_epoch( f"grad_scale is too small, exiting: {cur_grad_scale}" ) - #if params.batch_idx_train > 4000 and loss > 300 and params.wandb: - # wb.log({"valid/loss": 10000}) - # raise RuntimeError( - # f"divergence... exiting: loss={loss}" - # ) - if batch_idx % (params.log_interval*params.accum_grads) == 0: - #for n, p in model.named_parameters(): - # if 'adapter' in n: - # print(p) if params.multi_optim: cur_enc_lr = scheduler_enc.get_last_lr()[0] cur_dec_lr = scheduler_dec.get_last_lr()[0] @@ -1223,36 +1190,6 @@ def train_one_epoch( wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel}) wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel}) - ''' - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - if wb is not None and rank == 0: - numel = 1 / (params.accum_grads * valid_info["utterances"]) - #wb.log({"valid/loss": valid_info["loss"]*numel}) - wb.log({"valid/loss": numel*(valid_info["simple_loss"] - +valid_info["pruned_loss"] - +valid_info["ctc_loss"] - )}) - wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel}) - wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel}) - wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel}) - ''' loss_value = tot_loss["loss"] / tot_loss["utterances"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1449,17 +1386,6 @@ def run(rank, world_size, args, wb=None): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - ''' - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - 
optimizer=optimizer, - sp=sp, - params=params, - ) - ''' - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") @@ -1506,7 +1432,7 @@ def run(rank, world_size, args, wb=None): cleanup_dist() -def run_adapter(rank, world_size, args, wb=None): +def run_pea(rank, world_size, args, wb=None): """ Args: rank: @@ -1557,8 +1483,7 @@ def run_adapter(rank, world_size, args, wb=None): model_avg: Optional[nn.Module] = None if rank == 0: # model_avg is only used with rank 0 - #model_avg = copy.deepcopy(model).to(torch.float64) - model_avg = None + model_avg = copy.deepcopy(model).to(torch.float64) assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( @@ -1570,33 +1495,40 @@ def run_adapter(rank, world_size, args, wb=None): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - adapter_names = [] - adapter_param = [] - for n, p in model.named_parameters(): - if 'q_proj.bias' in n or 'fc1.bias' in n: - adapter_names.append(n) - adapter_param.append(p) - else: - p.requires_grad = False - - ''' - if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n: - adapter_names.append(n) - adapter_param.append(p) - elif 'joiner' in n or 'simple' in n or 'ctc' in n: - p.requires_grad = True - else: - p.requires_grad = False - ''' - optimizer_adapter = ScaledAdam( - adapter_param, - lr=params.adapter_lr, - clipping_scale=5.0, - parameters_names=[adapter_names], - ) + lora_modules = [] + for modules in model.modules(): + if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention): + for module in modules.modules(): + if isinstance(module, torch.nn.Linear): + lora_modules.append(LoRAHook( + module, + embedding_dim=args.encoder_dim, + rank=args.rank, + lora_alpha=args.lora_alpha, + )) - scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs) - optimizer, scheduler = optimizer_adapter, scheduler_adapter + if world_size > 1: + logging.info("Using DDP for LoRA") + for module in lora_modules: + module.lora = module.lora.to(device) + module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False) + + pea_names = [] + pea_param = [] + for i, module in enumerate(lora_modules): + for n, p in module.lora.named_parameters(): + new_n = str(i) + n + pea_names.append(new_n) + pea_param.append(p) + + optimizer_pea = ScaledAdam( + pea_param, + lr=params.pea_lr, + clipping_scale=5.0, + parameters_names=[pea_names], + ) + scheduler_pea = Eden(optimizer_pea, 10000, 7) + optimizer, scheduler = optimizer_pea, scheduler_pea librispeech = LibriSpeechAsrDataModule(args) train_cuts = librispeech.vox_cuts(option=params.spk_id) @@ -1611,7 +1543,7 @@ def run_adapter(rank, world_size, args, wb=None): train_dl = librispeech.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - + valid_cuts = librispeech.dev_clean_cuts(option=params.gender) valid_cuts += librispeech.dev_other_cuts(option=params.gender) valid_dl = librispeech.valid_dataloaders(valid_cuts) @@ -1643,6 +1575,7 @@ def run_adapter(rank, world_size, args, wb=None): world_size=world_size, rank=rank, wb=wb, + lora_modules=lora_modules, ) if params.print_diagnostics: @@ -1746,13 +1679,13 @@ def main(): world_size = args.world_size assert world_size >= 1 if world_size > 1: - mp.spawn(run if not args.add_adapter else run_adapter, + mp.spawn(run if not args.pea else run_pea, args=(world_size, args, wb), 
nprocs=world_size, join=True ) else: - if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb) + if args.pea: run_pea(rank=0, world_size=1, args=args, wb=wb) else: run(rank=0, world_size=1, args=args, wb=wb) torch.set_num_threads(1)
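During training, LoRAHook.save_checkpoint writes one file per hooked layer (lora_{iter_}_{i}.pt, with i following the order in which run_pea attached the hooks), but the diff does not show how these files are restored for decoding or later test-time adaptation. A possible loader that mirrors the attachment loop in run_pea is sketched below; the helper name load_lora_checkpoints and its argument list are chosen here for illustration only:

```python
import fairseq
import torch
import torch.nn as nn

from data2vec_audio import LoRAHook


def load_lora_checkpoints(model: nn.Module, save_dir, iter_: int,
                          encoder_dim: int, rank: int, lora_alpha: float):
    """Re-attach LoRA hooks and load the weights saved by LoRAHook.save_checkpoint.

    The hooks must be created in the same module order as during training so
    that index i lines up with the lora_{iter_}_{i}.pt files.
    """
    lora_modules = []
    for mha in model.modules():
        if isinstance(mha, fairseq.modules.multihead_attention.MultiheadAttention):
            for module in mha.modules():
                if isinstance(module, torch.nn.Linear):
                    lora_modules.append(LoRAHook(
                        module,
                        embedding_dim=encoder_dim,
                        rank=rank,
                        lora_alpha=lora_alpha,
                    ))

    for i, hook in enumerate(lora_modules):
        state = torch.load(f"{save_dir}/lora_{iter_}_{i}.pt", map_location="cpu")
        # load_state_dict copies values into the existing parameters, so the
        # adapter stays on whatever device LoRAHook placed it on.
        hook.lora.load_state_dict(state)
    return lora_modules
```

Since the mapping relies on a running index, any change to the encoder architecture between training and loading would shift the indices; keying the saved files by the layer's qualified module name may be the safer choice in practice.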