mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
first version of lora, could just run
This commit is contained in:
parent
87b19c0bd2
commit
a56034799b
@ -479,18 +479,18 @@ class LibriSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def gigaspeech_subset_small_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech subset-S cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
|
||||
self.args.manifest_dir / "cuts_DEV.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
|
||||
self.args.manifest_dir / "cuts_TEST.jsonl.gz"
|
||||
)
|
||||
|
@ -134,7 +134,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
|
||||
default=True,
|
||||
help="If true, finetune from a pre-trained checkpoint",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-mux",
|
||||
type=str2bool,
|
||||
@ -390,7 +390,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--base-lr",
|
||||
type=float,
|
||||
default=0.0045,
|
||||
default=0.045,
|
||||
help="""The base learning rate.
|
||||
It is set to a very small value as we are doing fine-tuning""",
|
||||
)
|
||||
@ -646,6 +646,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
causal=params.causal,
|
||||
chunk_size=_to_int_tuple(params.chunk_size),
|
||||
left_context_frames=_to_int_tuple(params.left_context_frames),
|
||||
use_lora=params.use_lora,
|
||||
lora_r=params.lora_r if params.use_lora else 0,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -1041,6 +1043,12 @@ def train_one_epoch(
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if "lora" in name:
|
||||
m.training = True
|
||||
else:
|
||||
m.training = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint_impl(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
@ -1177,7 +1185,6 @@ def train_one_epoch(
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
|
||||
)
|
||||
@ -1188,6 +1195,7 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
|
||||
)
|
||||
model.train()
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
@ -1257,7 +1265,7 @@ def run(rank, world_size, args):
|
||||
assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
|
||||
modules = params.init_modules.split(",") if params.init_modules else None
|
||||
checkpoints = load_model_params(
|
||||
ckpt=params.finetune_ckpt, model=model, init_modules=modules
|
||||
ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
|
||||
)
|
||||
# Need to update the model_avg if use initialisation
|
||||
if rank == 0:
|
||||
@ -1270,6 +1278,17 @@ def run(rank, world_size, args):
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
|
||||
# keep the original model untouched, only update the adapters
|
||||
num_trainable = 0
|
||||
for name, p in model.named_parameters():
|
||||
if "lora_A" in name or "lora_B" in name:
|
||||
p.requires_grad = True
|
||||
num_trainable += p.numel()
|
||||
else:
|
||||
p.requires_grad = False
|
||||
|
||||
logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
@ -1379,14 +1398,14 @@ def run(rank, world_size, args):
|
||||
librispeech.valid_dataloaders(gigaspeech_dev_cuts),
|
||||
]
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
# if not params.print_diagnostics:
|
||||
# scan_pessimistic_batches_for_oom(
|
||||
# model=model,
|
||||
# train_dl=train_dl,
|
||||
# optimizer=optimizer,
|
||||
# sp=sp,
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
|
1
egs/librispeech/ASR/zipformer_lora/model.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/model.py
|
1
egs/librispeech/ASR/zipformer_lora/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/optim.py
|
@ -23,6 +23,7 @@ import random
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@ -517,6 +518,94 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
||||
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
|
||||
return ans
|
||||
|
||||
class LoRALayer:
|
||||
def __init__(
|
||||
self,
|
||||
r: int,
|
||||
lora_alpha: int,
|
||||
lora_dropout: float,
|
||||
merge_weights: bool,
|
||||
):
|
||||
self.r = r
|
||||
self.lora_alpha = lora_alpha
|
||||
# Optional dropout
|
||||
if lora_dropout > 0.:
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
self.lora_dropout = lambda x: x
|
||||
# Mark the weight as unmerged
|
||||
self.merged = False
|
||||
self.merge_weights = merge_weights
|
||||
|
||||
class ScaledLinear_lora(nn.Linear, LoRALayer):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int=0,
|
||||
fan_in_fan_out: bool=False,
|
||||
lora_alpha: int=1,
|
||||
lora_dropout: float=0.0,
|
||||
initial_scale: float = 1.0,
|
||||
merge_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
|
||||
self.initial_scale = initial_scale
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
|
||||
self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
self.weight.requires_grad = False
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
# initialize the parameters
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, "lora_A"):
|
||||
initial_scale = self.initial_scale
|
||||
with torch.no_grad():
|
||||
self.weight[:] *= initial_scale
|
||||
if self.bias is not None:
|
||||
nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize B the same way as the default for nn.Linear and A to zero
|
||||
# this is different than what is described in the paper but should not affect performance
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool=True):
|
||||
def T(w):
|
||||
return w.transpose(0, 1) if self.fan_in_fan_out else w
|
||||
nn.Linear.train(self, mode)
|
||||
if mode:
|
||||
# We don't want the weights to be merged in training mode
|
||||
if self.merge_weights and self.merged:
|
||||
if self.r > 0:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
else:
|
||||
# When evaluating the model, we merge the weights for simplicity
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def T(w):
|
||||
return w.transpose(0, 1) if self.fan_in_fan_out else w
|
||||
if self.r > 0 and not self.merged:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
|
||||
return result + delta_result * self.scaling
|
||||
else:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
|
||||
"""
|
||||
|
@ -30,6 +30,7 @@ from scaling import (
|
||||
)
|
||||
from scaling import (
|
||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||
ScaledLinear_lora
|
||||
)
|
||||
from scaling import (
|
||||
ActivationDropoutAndLinear,
|
||||
@ -116,6 +117,8 @@ class Zipformer2(EncoderInterface):
|
||||
causal: bool = False,
|
||||
chunk_size: Tuple[int] = [-1],
|
||||
left_context_frames: Tuple[int] = [-1],
|
||||
use_lora: bool = True,
|
||||
lora_r: int = 0,
|
||||
) -> None:
|
||||
super(Zipformer2, self).__init__()
|
||||
|
||||
@ -152,6 +155,8 @@ class Zipformer2(EncoderInterface):
|
||||
self.chunk_size = chunk_size
|
||||
self.left_context_frames = left_context_frames
|
||||
|
||||
self.lora_r = lora_r if use_lora else 0
|
||||
|
||||
for u, d in zip(encoder_unmasked_dim, encoder_dim):
|
||||
assert u <= d
|
||||
|
||||
@ -171,6 +176,7 @@ class Zipformer2(EncoderInterface):
|
||||
dropout=dropout,
|
||||
cnn_module_kernel=cnn_module_kernel[i],
|
||||
causal=causal,
|
||||
lora_r=self.lora_r,
|
||||
)
|
||||
|
||||
# For the segment of the warmup period, we let the Conv2dSubsampling
|
||||
@ -589,6 +595,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
bypass_skip_rate: FloatLike = ScheduledFloat(
|
||||
(0.0, 0.5), (4000.0, 0.02), default=0
|
||||
),
|
||||
lora_r: int = 0,
|
||||
lora_alpha: int = 4,
|
||||
lora_dropout: float = 0.0,
|
||||
) -> None:
|
||||
super(Zipformer2EncoderLayer, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -620,6 +629,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
query_head_dim=query_head_dim,
|
||||
pos_head_dim=pos_head_dim,
|
||||
dropout=0.0,
|
||||
lora_r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
|
||||
self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
|
||||
@ -1508,6 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
dropout: dropout probability for attn_output_weights. Default: 0.0.
|
||||
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
|
||||
any given call to forward(), in training time.
|
||||
lora_r: the bottleneck dimension of LoRA
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -1519,6 +1532,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
pos_head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
|
||||
lora_r: int = 0,
|
||||
lora_alpha: int = 4,
|
||||
lora_dropout: float=0.0
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -1537,8 +1553,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# dividing it between the query and key. Note: this module is intended
|
||||
# to be used with the ScaledAdam optimizer; with most other optimizers,
|
||||
# it would be necessary to apply the scaling factor in the forward function.
|
||||
self.in_proj = ScaledLinear(
|
||||
embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
|
||||
# self.in_proj = ScaledLinear(
|
||||
# embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
|
||||
# )
|
||||
self.in_proj = ScaledLinear_lora(
|
||||
in_features=embed_dim,
|
||||
out_features=in_proj_dim,
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
initial_scale=query_head_dim**-0.25,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.whiten_keys = Whiten(
|
||||
|
Loading…
x
Reference in New Issue
Block a user