First version of LoRA fine-tuning; it can already be run.

This commit is contained in:
marcoyang 2024-03-07 12:28:39 +08:00
parent 87b19c0bd2
commit a56034799b
6 changed files with 152 additions and 17 deletions

View File

@@ -479,18 +479,18 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def gigaspeech_subset_small_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech subset-S cuts")
return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
self.args.manifest_dir / "cuts_DEV.jsonl.gz"
)
@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
self.args.manifest_dir / "cuts_TEST.jsonl.gz"
)

View File

@@ -390,7 +390,7 @@ def get_parser():
parser.add_argument(
"--base-lr",
type=float,
default=0.0045,
default=0.045,
help="""The base learning rate.
It is set to a very small value as we are doing fine-tuning""",
)
@@ -646,6 +646,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
use_lora=params.use_lora,
lora_r=params.lora_r if params.use_lora else 0,
)
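# The two new kwargs above wire LoRA into the Zipformer2 encoder. With
# use_lora=False the rank passed down is forced to 0, in which case
# ScaledLinear_lora (see scaling.py below) skips the adapter and the weight
# freezing and degenerates to a plain nn.Linear.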
return encoder
@@ -1041,6 +1043,12 @@ def train_one_epoch(
saved_bad_model = False
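# Put only the LoRA-related modules in training mode and keep the frozen
# backbone in eval-style mode. Assigning .training directly (instead of
# calling .train()/.eval()) flips just the flag, so ScaledLinear_lora.train(),
# which merges/unmerges the adapter weights, is not triggered here.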
for name, m in model.named_modules():
if "lora" in name:
m.training = True
else:
m.training = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -1177,7 +1185,6 @@ def train_one_epoch(
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
)
@@ -1188,6 +1195,7 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
)
model.train()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@@ -1257,7 +1265,7 @@ def run(rank, world_size, args):
assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
)
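# strict=False is needed because the pre-trained checkpoint contains no
# lora_A / lora_B tensors; the adapters keep their fresh initialisation
# (lora_B is zero, so the model initially behaves exactly like the
# pre-trained one).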
# Need to update the model_avg if use initialisation
if rank == 0:
@@ -1270,6 +1278,17 @@ def run(rank, world_size, args):
params=params, model=model, model_avg=model_avg
)
# keep the original model untouched, only update the adapters
num_trainable = 0
for name, p in model.named_parameters():
if "lora_A" in name or "lora_B" in name:
p.requires_grad = True
num_trainable += p.numel()
else:
p.requires_grad = False
logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
model.to(device)
if world_size > 1:
logging.info("Using DDP")
@@ -1379,14 +1398,14 @@ def run(rank, world_size, args):
librispeech.valid_dataloaders(gigaspeech_dev_cuts),
]
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:

View File

@@ -0,0 +1 @@
../zipformer/model.py

View File

@@ -0,0 +1 @@
../zipformer/optim.py

View File

@@ -23,6 +23,7 @@ import random
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@@ -517,6 +518,94 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
return ans
class LoRALayer:
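"""Mixin holding the LoRA hyper-parameters shared by LoRA-augmented layers:
the rank r, the scaling numerator lora_alpha, an optional dropout applied to
the LoRA input, and merge_weights, which controls whether the low-rank update
is folded into the frozen weight at eval time. This closely follows the
reference loralib code (https://github.com/microsoft/LoRA)."""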
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
class ScaledLinear_lora(nn.Linear, LoRALayer):
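"""A LoRA-augmented variant of ScaledLinear. For r > 0 the base weight is
frozen (and, as in ScaledLinear, scaled by initial_scale at initialisation)
and the effective weight becomes
    W + (lora_alpha / r) * lora_B @ lora_A,
where lora_A has shape (r, in_features) and lora_B has shape (out_features, r).
lora_B is initialised to zero, so training starts from the behaviour of the
base layer. For r <= 0 this reduces to a plain nn.Linear."""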
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
fan_in_fan_out: bool = False,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
initial_scale: float = 1.0,
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.initial_scale = initial_scale
self.fan_in_fan_out = fan_in_fan_out
if r > 0:
self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
self.scaling = self.lora_alpha / self.r
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
# initialize the parameters
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
initial_scale = self.initial_scale
with torch.no_grad():
self.weight[:] *= initial_scale
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
if hasattr(self, 'lora_A'):
# initialize B the same way as the default for nn.Linear and A to zero
# this is different than what is described in the paper but should not affect performance
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if mode:
# We don't want the weights to be merged in training mode
if self.merge_weights and self.merged:
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
else:
# When evaluating the model, we merge the weights for simplicity
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
return result + delta_result * self.scaling
else:
return F.linear(x, T(self.weight), bias=self.bias)
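# Illustrative usage sketch (not part of this commit): since lora_B starts at
# zero, the unmerged (training) path and the merged (eval) path produce the
# same output right after initialisation, e.g.
#
#     layer = ScaledLinear_lora(in_features=192, out_features=384, r=4, lora_alpha=4)
#     x = torch.randn(2, 100, 192)
#     layer.train()
#     y_unmerged = layer(x)  # W x + b + (dropout(x) @ A^T @ B^T) * alpha / r
#     layer.eval()           # folds (B @ A) * alpha / r into layer.weight
#     y_merged = layer(x)
#     torch.testing.assert_close(y_unmerged, y_merged)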
def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
"""

View File

@@ -30,6 +30,7 @@ from scaling import (
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
ScaledLinear_lora
)
from scaling import (
ActivationDropoutAndLinear,
@@ -116,6 +117,8 @@ class Zipformer2(EncoderInterface):
causal: bool = False,
chunk_size: Tuple[int] = [-1],
left_context_frames: Tuple[int] = [-1],
use_lora: bool = True,
lora_r: int = 0,
) -> None:
super(Zipformer2, self).__init__()
@@ -152,6 +155,8 @@ class Zipformer2(EncoderInterface):
self.chunk_size = chunk_size
self.left_context_frames = left_context_frames
self.lora_r = lora_r if use_lora else 0
for u, d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
@@ -171,6 +176,7 @@ class Zipformer2(EncoderInterface):
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[i],
causal=causal,
lora_r=self.lora_r,
)
# For the segment of the warmup period, we let the Conv2dSubsampling
@@ -589,6 +595,9 @@ class Zipformer2EncoderLayer(nn.Module):
bypass_skip_rate: FloatLike = ScheduledFloat(
(0.0, 0.5), (4000.0, 0.02), default=0
),
lora_r: int = 0,
lora_alpha: int = 4,
lora_dropout: float = 0.0,
) -> None:
super(Zipformer2EncoderLayer, self).__init__()
self.embed_dim = embed_dim
@@ -620,6 +629,9 @@ class Zipformer2EncoderLayer(nn.Module):
query_head_dim=query_head_dim,
pos_head_dim=pos_head_dim,
dropout=0.0,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
@@ -1508,6 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
dropout: dropout probability for attn_output_weights. Default: 0.0.
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
lora_r: the bottleneck dimension (rank) of the LoRA adapter; 0 disables LoRA.
lora_alpha: numerator of the LoRA scaling factor (the low-rank update is scaled by lora_alpha / lora_r).
lora_dropout: dropout probability applied to the input of the LoRA branch.
"""
def __init__(
@@ -1519,6 +1532,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_head_dim: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
lora_r: int = 0,
lora_alpha: int = 4,
lora_dropout: float = 0.0,
) -> None:
super().__init__()
self.embed_dim = embed_dim
@@ -1537,8 +1553,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# dividing it between the query and key. Note: this module is intended
# to be used with the ScaledAdam optimizer; with most other optimizers,
# it would be necessary to apply the scaling factor in the forward function.
self.in_proj = ScaledLinear(
embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
# self.in_proj = ScaledLinear(
# embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
# )
self.in_proj = ScaledLinear_lora(
in_features=embed_dim,
out_features=in_proj_dim,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
initial_scale=query_head_dim**-0.25,
bias=True,
)
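# Note: in this commit only in_proj (the joint query/key/positional-encoding
# projection of the attention weights) is replaced with ScaledLinear_lora;
# the other projections in the encoder still use the original ScaledLinear.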
self.whiten_keys = Whiten(