Repository: https://github.com/k2-fsa/icefall.git (mirror)
Commit: a56034799b (parent: 87b19c0bd2)
Commit message: first version of lora, could just run
@@ -479,18 +479,18 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def gigaspeech_subset_small_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech subset-S cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")

     @lru_cache()
     def gigaspeech_dev_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech dev cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
         )

     @lru_cache()
     def gigaspeech_test_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech test cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+            self.args.manifest_dir / "cuts_TEST.jsonl.gz"
         )
@@ -134,7 +134,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         default=True,
         help="If true, finetune from a pre-trained checkpoint",
     )

     parser.add_argument(
         "--use-mux",
         type=str2bool,
@@ -390,7 +390,7 @@ def get_parser():
     parser.add_argument(
         "--base-lr",
         type=float,
-        default=0.0045,
+        default=0.045,
         help="""The base learning rate.
         It is set to a very small value as we are doing fine-tuning""",
     )
@@ -646,6 +646,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        use_lora=params.use_lora,
+        lora_r=params.lora_r if params.use_lora else 0,
     )
     return encoder

@@ -1041,6 +1043,12 @@ def train_one_epoch(

     saved_bad_model = False

+    for name, m in model.named_modules():
+        if "lora" in name:
+            m.training = True
+        else:
+            m.training = False
+
     def save_bad_model(suffix: str = ""):
         save_checkpoint_impl(
             filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -1177,7 +1185,6 @@ def train_one_epoch(
                     valid_dl=valid_dl,
                     world_size=world_size,
                 )
-                model.train()
                 logging.info(
                     f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
                 )
@@ -1188,6 +1195,7 @@ def train_one_epoch(
                     valid_info.write_summary(
                         tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
                     )
+            model.train()

         loss_value = tot_loss["loss"] / tot_loss["frames"]
         params.train_loss = loss_value
@@ -1257,7 +1265,7 @@ def run(rank, world_size, args):
         assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
         modules = params.init_modules.split(",") if params.init_modules else None
         checkpoints = load_model_params(
-            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
         )
         # Need to update the model_avg if use initialisation
         if rank == 0:
@@ -1270,6 +1278,17 @@ def run(rank, world_size, args):
             params=params, model=model, model_avg=model_avg
         )

+    # keep the original model untouched, only update the adapters
+    num_trainable = 0
+    for name, p in model.named_parameters():
+        if "lora_A" in name or "lora_B" in name:
+            p.requires_grad = True
+            num_trainable += p.numel()
+        else:
+            p.requires_grad = False
+
+    logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable / num_param * 100))
+
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
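For reference, the freezing logic above can be exercised in isolation. The following is a minimal sketch, not part of the commit: it assumes the scaling module introduced further down (which defines ScaledLinear_lora) is importable from the working directory, and the toy dimensions are arbitrary.

    import torch.nn as nn
    from scaling import ScaledLinear_lora  # assumption: run from the new zipformer_lora directory

    toy = nn.Sequential(
        nn.Linear(80, 256),                              # ordinary layer: frozen below
        ScaledLinear_lora(256, 256, r=8, lora_alpha=4),  # contributes "lora_A" / "lora_B" parameters
    )

    num_param = sum(p.numel() for p in toy.parameters())
    num_trainable = 0
    for name, p in toy.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            p.requires_grad = True
            num_trainable += p.numel()
        else:
            p.requires_grad = False

    print(
        "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
            num_trainable, num_trainable / num_param * 100
        )
    )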
@@ -1379,14 +1398,14 @@ def run(rank, world_size, args):
         librispeech.valid_dataloaders(gigaspeech_dev_cuts),
     ]

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
egs/librispeech/ASR/zipformer_lora/model.py (new symbolic link)
@@ -0,0 +1 @@
+../zipformer/model.py
egs/librispeech/ASR/zipformer_lora/optim.py (new symbolic link)
@@ -0,0 +1 @@
+../zipformer/optim.py
@@ -23,6 +23,7 @@ import random
 import torch
 import math
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor


@@ -517,6 +518,94 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
             torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
     return ans

+class LoRALayer:
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+class ScaledLinear_lora(nn.Linear, LoRALayer):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        fan_in_fan_out: bool = False,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        initial_scale: float = 1.0,
+        merge_weights: bool = True,
+        **kwargs,
+    ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+
+        self.initial_scale = initial_scale
+        self.fan_in_fan_out = fan_in_fan_out
+        if r > 0:
+            self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
+            self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
+            self.scaling = self.lora_alpha / self.r
+            self.weight.requires_grad = False
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # initialize the parameters
+        nn.Linear.reset_parameters(self)
+        if hasattr(self, "lora_A"):
+            initial_scale = self.initial_scale
+            with torch.no_grad():
+                self.weight[:] *= initial_scale
+                if self.bias is not None:
+                    nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+        if hasattr(self, "lora_A"):
+            # initialize B the same way as the default for nn.Linear and A to zero
+            # this is different than what is described in the paper but should not affect performance
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        def T(w):
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
+        nn.Linear.train(self, mode)
+        if mode:
+            # We don't want the weights to be merged in training mode
+            if self.merge_weights and self.merged:
+                if self.r > 0:
+                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+        else:
+            # When evaluating the model, we merge the weights for simplicity
+            if self.merge_weights and not self.merged:
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        def T(w):
+            return w.transpose(0, 1) if self.fan_in_fan_out else w
+        if self.r > 0 and not self.merged:
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
+            return result + delta_result * self.scaling
+        else:
+            return F.linear(x, T(self.weight), bias=self.bias)

 def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
     """
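A small usage sketch of the ScaledLinear_lora defined above (not part of the commit; it assumes the scaling module is importable from the working directory). In training mode the low-rank update is applied on the fly; calling eval() folds lora_B @ lora_A * scaling into the frozen weight, so with lora_dropout = 0 both modes compute the same function up to floating-point tolerance:

    import torch
    from scaling import ScaledLinear_lora  # assumption: run from the zipformer_lora directory

    layer = ScaledLinear_lora(in_features=8, out_features=16, r=4, lora_alpha=4)
    torch.nn.init.normal_(layer.lora_B, std=0.02)  # lora_B starts at zero; perturb it so the check is non-trivial
    x = torch.randn(2, 8)

    layer.train()   # unmerged path: x @ W^T + b + scaling * dropout(x) @ A^T @ B^T
    y_train = layer(x)

    layer.eval()    # train(False) merges the low-rank update into the frozen weight
    y_eval = layer(x)

    assert torch.allclose(y_train, y_eval, atol=1e-5)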
@@ -30,6 +30,7 @@ from scaling import (
 )
 from scaling import (
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    ScaledLinear_lora,
 )
 from scaling import (
     ActivationDropoutAndLinear,
@@ -116,6 +117,8 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        use_lora: bool = True,
+        lora_r: int = 0,
     ) -> None:
         super(Zipformer2, self).__init__()

@@ -152,6 +155,8 @@ class Zipformer2(EncoderInterface):
         self.chunk_size = chunk_size
         self.left_context_frames = left_context_frames

+        self.lora_r = lora_r if use_lora else 0
+
         for u, d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d

@@ -171,6 +176,7 @@ class Zipformer2(EncoderInterface):
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
                 causal=causal,
+                lora_r=self.lora_r,
             )

             # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -589,6 +595,9 @@ class Zipformer2EncoderLayer(nn.Module):
         bypass_skip_rate: FloatLike = ScheduledFloat(
             (0.0, 0.5), (4000.0, 0.02), default=0
         ),
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
@@ -620,6 +629,9 @@ class Zipformer2EncoderLayer(nn.Module):
             query_head_dim=query_head_dim,
             pos_head_dim=pos_head_dim,
             dropout=0.0,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
         )

         self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
@@ -1508,6 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
           dropout: dropout probability for attn_output_weights. Default: 0.0.
           pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
              any given call to forward(), in training time.
+          lora_r: the bottleneck dimension of LoRA
     """

     def __init__(
@@ -1519,6 +1532,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         pos_head_dim: int,
         dropout: float = 0.0,
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1537,8 +1553,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # dividing it between the query and key.  Note: this module is intended
         # to be used with the ScaledAdam optimizer; with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj = ScaledLinear(
-            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+        # self.in_proj = ScaledLinear(
+        #     embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+        # )
+        self.in_proj = ScaledLinear_lora(
+            in_features=embed_dim,
+            out_features=in_proj_dim,
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            initial_scale=query_head_dim**-0.25,
+            bias=True,
         )

         self.whiten_keys = Whiten(
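The LoRA in_proj keeps the base weight frozen and adds only the two small matrices lora_A (r x in_features) and lora_B (out_features x r) shown in the diff. A back-of-the-envelope sketch of the parameter overhead (not part of the commit; the dimensions below are illustrative, not taken from any particular Zipformer2 configuration):

    embed_dim = 384      # assumed example value
    in_proj_dim = 1024   # assumed example value
    r = 8                # lora_r

    full = embed_dim * in_proj_dim        # frozen base weight: 393216
    lora = r * (embed_dim + in_proj_dim)  # lora_A + lora_B:     11264
    print(f"LoRA adds {lora / full * 100:.2f}% extra parameters")  # about 2.86%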