diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 2ab051e83..d4c02df5a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -479,18 +479,18 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def gigaspeech_subset_small_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech subset-S cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
 
     @lru_cache()
     def gigaspeech_dev_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech dev cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
         )
 
     @lru_cache()
     def gigaspeech_test_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech test cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+            self.args.manifest_dir / "cuts_TEST.jsonl.gz"
         )
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
index 8d3f7c68d..9b0a8312b 100755
--- a/egs/librispeech/ASR/zipformer_lora/finetune.py
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -134,7 +134,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         default=True,
         help="If true, finetune from a pre-trained checkpoint",
     )
-    
+
     parser.add_argument(
         "--use-mux",
         type=str2bool,
@@ -390,7 +390,7 @@ def get_parser():
     parser.add_argument(
         "--base-lr",
         type=float,
-        default=0.0045,
+        default=0.045,
         help="""The base learning rate. It is set to a very small value as
         we are doing fine-tuning""",
     )
@@ -646,6 +646,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        use_lora=params.use_lora,
+        lora_r=params.lora_r if params.use_lora else 0,
     )
     return encoder
 
@@ -1041,6 +1043,12 @@ def train_one_epoch(
 
     saved_bad_model = False
 
+    for name, m in model.named_modules():
+        if "lora" in name:
+            m.training = True
+        else:
+            m.training = False
+
     def save_bad_model(suffix: str = ""):
         save_checkpoint_impl(
             filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -1177,7 +1185,6 @@
                     valid_dl=valid_dl,
                     world_size=world_size,
                 )
-                model.train()
                 logging.info(
                     f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
                 )
@@ -1188,6 +1195,7 @@
                     valid_info.write_summary(
                         tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
                     )
+            model.train()
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -1257,7 +1265,7 @@ def run(rank, world_size, args):
         assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
         modules = params.init_modules.split(",") if params.init_modules else None
         checkpoints = load_model_params(
-            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
         )
         # Need to update the model_avg if use initialisation
         if rank == 0:
@@ -1270,6 +1278,17 @@
             params=params, model=model, model_avg=model_avg
         )
 
+    # keep the original model untouched, only update the adapters
+    num_trainable = 0
+    for name, p in model.named_parameters():
+        if "lora_A" in name or "lora_B" in name:
+            p.requires_grad = True
+            num_trainable += p.numel()
+        else:
+            p.requires_grad = False
+
+    logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable / num_param * 100))
+
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
@@ -1379,14 +1398,14 @@
         librispeech.valid_dataloaders(gigaspeech_dev_cuts),
     ]
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
diff --git a/egs/librispeech/ASR/zipformer_lora/model.py b/egs/librispeech/ASR/zipformer_lora/model.py
new file mode 120000
index 000000000..0c6fe6112
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/model.py
@@ -0,0 +1 @@
+../zipformer/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/optim.py b/egs/librispeech/ASR/zipformer_lora/optim.py
new file mode 120000
index 000000000..207eecfcd
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/optim.py
@@ -0,0 +1 @@
+../zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
index c0f1e3087..7aeb25721 100644
--- a/egs/librispeech/ASR/zipformer_lora/scaling.py
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -23,6 +23,7 @@
 import random
 import torch
 import math
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
@@ -517,6 +518,94 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
     torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
     return ans
 
 
+class LoRALayer:
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+
+class ScaledLinear_lora(nn.Linear, LoRALayer):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        fan_in_fan_out: bool = False,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        initial_scale: float = 1.0,
+        merge_weights: bool = True,
+        **kwargs,
+    ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+
+        self.initial_scale = initial_scale
+        self.fan_in_fan_out = fan_in_fan_out
+        if r > 0:
+            self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
+            self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
+            self.scaling = self.lora_alpha / self.r
+            self.weight.requires_grad = False
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # initialize the parameters
+        nn.Linear.reset_parameters(self)
+        if hasattr(self, "lora_A"):
+            initial_scale = self.initial_scale
+            with torch.no_grad():
+                self.weight[:] *= initial_scale
+                if self.bias is not None:
+                    nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+        if hasattr(self, "lora_A"):
+            # initialize B the same way as the default for nn.Linear and A to zero
+            # this is different than what is described in the paper but should not affect performance
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        def T(w):
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
+        nn.Linear.train(self, mode)
+        if mode:
+            # We don't want the weights to be merged in training mode
+            if self.merge_weights and self.merged:
+                if self.r > 0:
+                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+        else:
+            # When evaluating the model, we merge the weights for simplicity
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        def T(w):
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
+        if self.r > 0 and not self.merged:
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
+            return result + delta_result * self.scaling
+        else:
+            return F.linear(x, T(self.weight), bias=self.bias)
+
+
 def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
     """
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
index 61ae378d8..09f027d75 100644
--- a/egs/librispeech/ASR/zipformer_lora/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -30,6 +30,7 @@ from scaling import (
 )
 from scaling import (
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    ScaledLinear_lora,
 )
 from scaling import (
     ActivationDropoutAndLinear,
@@ -116,6 +117,8 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        use_lora: bool = True,
+        lora_r: int = 0,
     ) -> None:
         super(Zipformer2, self).__init__()
 
@@ -152,6 +155,8 @@ class Zipformer2(EncoderInterface):
         self.chunk_size = chunk_size
         self.left_context_frames = left_context_frames
 
+        self.lora_r = lora_r if use_lora else 0
+
         for u, d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
 
@@ -171,6 +176,7 @@ class Zipformer2(EncoderInterface):
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
                 causal=causal,
+                lora_r=self.lora_r,
             )
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -589,6 +595,9 @@ class Zipformer2EncoderLayer(nn.Module):
         bypass_skip_rate: FloatLike = ScheduledFloat(
             (0.0, 0.5), (4000.0, 0.02), default=0
         ),
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
@@ -620,6 +629,9 @@ class Zipformer2EncoderLayer(nn.Module):
             query_head_dim=query_head_dim,
             pos_head_dim=pos_head_dim,
             dropout=0.0,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
         )
 
         self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
@@ -1508,6 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        dropout: dropout probability for attn_output_weights. Default: 0.0.
        pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
           any given call to forward(), in training time.
+       lora_r: the bottleneck dimension of LoRA
     """
 
     def __init__(
@@ -1519,6 +1532,9 @@
         pos_head_dim: int,
         dropout: float = 0.0,
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1537,8 +1553,17 @@
         # dividing it between the query and key.  Note: this module is intended
         # to be used with the ScaledAdam optimizer; with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj = ScaledLinear(
-            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+        # self.in_proj = ScaledLinear(
+        #     embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+        # )
+        self.in_proj = ScaledLinear_lora(
+            in_features=embed_dim,
+            out_features=in_proj_dim,
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            initial_scale=query_head_dim**-0.25,
+            bias=True,
         )
 
         self.whiten_keys = Whiten(
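Note on the LoRA update implemented above (illustration only, not part of the patch): ScaledLinear_lora keeps the pre-trained weight W frozen and adds a trainable low-rank term, so the layer computes y = x W^T + b + (alpha / r) * x A^T B^T, and at evaluation time train(False) folds B @ A back into W so inference carries no extra matmuls. The sketch below restates that idea in plain PyTorch under those assumptions; the class name TinyLoRALinear and its merge() helper are hypothetical and do not exist in the recipe.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLoRALinear(nn.Linear):
    """Hypothetical illustration of a LoRA linear layer (not the recipe's API).

    Computes y = x @ W.T + b + (alpha / r) * (x @ A.T @ B.T).  The base weight W
    is frozen; the update flows through the low-rank factors A (r x in) and
    B (out x r).  B starts at zero, so the layer initially matches nn.Linear.
    """

    def __init__(self, in_features: int, out_features: int, r: int = 4, alpha: int = 4):
        super().__init__(in_features, out_features)
        self.weight.requires_grad = False  # the pre-trained weight stays frozen
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight, self.bias)
        delta = x @ self.lora_A.t() @ self.lora_B.t()
        return base + self.scaling * delta

    @torch.no_grad()
    def merge(self) -> None:
        # Fold the adapter into the base weight, analogous to what
        # ScaledLinear_lora.train(False) does when merge_weights is True.
        self.weight += self.scaling * (self.lora_B @ self.lora_A)


if __name__ == "__main__":
    layer = TinyLoRALinear(16, 32, r=4)
    x = torch.randn(2, 16)
    y_unmerged = layer(x)
    layer.merge()
    y_merged = F.linear(x, layer.weight, layer.bias)
    # The merged weight reproduces the adapter output up to float rounding.
    print(torch.allclose(y_unmerged, y_merged, atol=1e-5))

This mirrors what the finetune.py changes above do at the model level: every parameter whose name does not contain lora_A or lora_B has requires_grad set to False, so only the adapters are updated and num_trainable stays a small fraction of the full model.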