mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
4929de22dc
commit
8f207043c5
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -38,6 +38,24 @@ from convolution import ConvolutionModule
|
||||
logger = logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LoRAHook():
|
||||
def __init__(self, module, embedding_dim, rank, lora_alpha):
|
||||
self.hook = module.register_forward_hook(self.hook_fn)
|
||||
self.lora = LoRAModule(
|
||||
embedding_dim=embedding_dim,
|
||||
rank=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
)
|
||||
def hook_fn(self, module, input, output):
|
||||
lora_out = self.lora(input[0])
|
||||
output += lora_out
|
||||
|
||||
def save_checkpoint(self, i, iter_, save_dir):
|
||||
if isinstance(self.lora, DDP):
|
||||
lora = self.lora.module
|
||||
torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")
|
||||
|
||||
|
||||
class TransformerEncoderAdapter(TransformerEncoder):
|
||||
def __init__(self, args: Wav2Vec2Config):
|
||||
super().__init__(args)
|
||||
|
||||
@ -101,6 +101,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
from data2vec_encoder import FairSeqData2VecEncoder
|
||||
from data2vec_audio import LoRAModule, LoRAHook
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints
|
||||
@ -123,8 +124,8 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
import wandb
|
||||
import fairseq
|
||||
|
||||
#from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
@ -138,26 +139,33 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
model.encoder.num_updates = int(batch_count)
|
||||
|
||||
|
||||
def add_adapter_arguments(parser: argparse.ArgumentParser):
|
||||
def add_pea_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--add-adapter",
|
||||
"--adapter",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="add adapter to rep model's encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adapter-lr",
|
||||
type=float,
|
||||
default=0.0001,
|
||||
help="adapter learning rate"
|
||||
"--bitfit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="bias only training for PEA"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Low Rank Adaptation training for PEA"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gender",
|
||||
type=str,
|
||||
default='male',
|
||||
help="select gender"
|
||||
"--pea-lr",
|
||||
type=float,
|
||||
default=0.0001,
|
||||
help="PEA learning rate"
|
||||
)
|
||||
|
||||
|
||||
@ -314,12 +322,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -528,10 +530,16 @@ def get_parser():
|
||||
default=True,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--pea",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to train parameter efficient adaptation",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
add_rep_arguments(parser)
|
||||
add_adapter_arguments(parser)
|
||||
add_pea_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@ -588,7 +596,7 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 5,
|
||||
"log_interval": 20,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
# parameters for zipformer
|
||||
@ -663,7 +671,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
@ -672,8 +680,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
prompt=params.prompt,
|
||||
sid=params.spk_id,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -714,7 +720,7 @@ def load_checkpoint_if_available(
|
||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||
elif params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
elif params.add_adapter:
|
||||
elif params.pea:
|
||||
filename = params.exp_dir / f"../d2v-base-T.pt"
|
||||
else:
|
||||
return None
|
||||
@ -727,7 +733,7 @@ def load_checkpoint_if_available(
|
||||
model_avg=model_avg,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
strict=True if not params.add_adapter else False,
|
||||
strict=True if not params.pea else False,
|
||||
)
|
||||
|
||||
keys = [
|
||||
@ -1001,6 +1007,7 @@ def train_one_epoch(
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
wb = None,
|
||||
lora_modules = None,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
@ -1100,54 +1107,23 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
'''
|
||||
if (
|
||||
rank == 0
|
||||
and params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
'''
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
'''
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
'''
|
||||
|
||||
if rank == 0:
|
||||
for i, lora in enumerate(lora_modules):
|
||||
lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
|
||||
|
||||
if batch_idx % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
'''
|
||||
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
'''
|
||||
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
@ -1156,16 +1132,7 @@ def train_one_epoch(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
#if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
|
||||
# wb.log({"valid/loss": 10000})
|
||||
# raise RuntimeError(
|
||||
# f"divergence... exiting: loss={loss}"
|
||||
# )
|
||||
|
||||
if batch_idx % (params.log_interval*params.accum_grads) == 0:
|
||||
#for n, p in model.named_parameters():
|
||||
# if 'adapter' in n:
|
||||
# print(p)
|
||||
if params.multi_optim:
|
||||
cur_enc_lr = scheduler_enc.get_last_lr()[0]
|
||||
cur_dec_lr = scheduler_dec.get_last_lr()[0]
|
||||
@ -1223,36 +1190,6 @@ def train_one_epoch(
|
||||
wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
|
||||
wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})
|
||||
|
||||
'''
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if wb is not None and rank == 0:
|
||||
numel = 1 / (params.accum_grads * valid_info["utterances"])
|
||||
#wb.log({"valid/loss": valid_info["loss"]*numel})
|
||||
wb.log({"valid/loss": numel*(valid_info["simple_loss"]
|
||||
+valid_info["pruned_loss"]
|
||||
+valid_info["ctc_loss"]
|
||||
)})
|
||||
wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
|
||||
wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
|
||||
wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
|
||||
'''
|
||||
loss_value = tot_loss["loss"] / tot_loss["utterances"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
@ -1449,17 +1386,6 @@ def run(rank, world_size, args, wb=None):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
'''
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
'''
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
@ -1506,7 +1432,7 @@ def run(rank, world_size, args, wb=None):
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def run_adapter(rank, world_size, args, wb=None):
|
||||
def run_pea(rank, world_size, args, wb=None):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
@ -1557,8 +1483,7 @@ def run_adapter(rank, world_size, args, wb=None):
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
#model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
model_avg = None
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
@ -1570,33 +1495,40 @@ def run_adapter(rank, world_size, args, wb=None):
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
adapter_names = []
|
||||
adapter_param = []
|
||||
for n, p in model.named_parameters():
|
||||
if 'q_proj.bias' in n or 'fc1.bias' in n:
|
||||
adapter_names.append(n)
|
||||
adapter_param.append(p)
|
||||
else:
|
||||
p.requires_grad = False
|
||||
|
||||
'''
|
||||
if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
|
||||
adapter_names.append(n)
|
||||
adapter_param.append(p)
|
||||
elif 'joiner' in n or 'simple' in n or 'ctc' in n:
|
||||
p.requires_grad = True
|
||||
else:
|
||||
p.requires_grad = False
|
||||
'''
|
||||
optimizer_adapter = ScaledAdam(
|
||||
adapter_param,
|
||||
lr=params.adapter_lr,
|
||||
clipping_scale=5.0,
|
||||
parameters_names=[adapter_names],
|
||||
)
|
||||
lora_modules = []
|
||||
for modules in model.modules():
|
||||
if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
|
||||
for module in modules.modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
lora_modules.append(LoRAHook(
|
||||
module,
|
||||
embedding_dim=args.encoder_dim,
|
||||
rank=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
))
|
||||
|
||||
scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs)
|
||||
optimizer, scheduler = optimizer_adapter, scheduler_adapter
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP for LoRA")
|
||||
for module in lora_modules:
|
||||
module.lora = module.lora.to(device)
|
||||
module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
|
||||
|
||||
pea_names = []
|
||||
pea_param = []
|
||||
for i, module in enumerate(lora_modules):
|
||||
for n, p in module.lora.named_parameters():
|
||||
new_n = str(i) + n
|
||||
pea_names.append(new_n)
|
||||
pea_param.append(p)
|
||||
|
||||
optimizer_pea = ScaledAdam(
|
||||
pea_param,
|
||||
lr=params.pea_lr,
|
||||
clipping_scale=5.0,
|
||||
parameters_names=[pea_names],
|
||||
)
|
||||
scheduler_pea = Eden(optimizer_pea, 10000, 7)
|
||||
optimizer, scheduler = optimizer_pea, scheduler_pea
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_cuts = librispeech.vox_cuts(option=params.spk_id)
|
||||
@ -1611,7 +1543,7 @@ def run_adapter(rank, world_size, args, wb=None):
|
||||
train_dl = librispeech.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
|
||||
valid_cuts += librispeech.dev_other_cuts(option=params.gender)
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
@ -1643,6 +1575,7 @@ def run_adapter(rank, world_size, args, wb=None):
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
wb=wb,
|
||||
lora_modules=lora_modules,
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
@ -1746,13 +1679,13 @@ def main():
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run if not args.add_adapter else run_adapter,
|
||||
mp.spawn(run if not args.pea else run_pea,
|
||||
args=(world_size, args, wb),
|
||||
nprocs=world_size,
|
||||
join=True
|
||||
)
|
||||
else:
|
||||
if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
|
||||
if args.pea: run_pea(rank=0, world_size=1, args=args, wb=wb)
|
||||
else: run(rank=0, world_size=1, args=args, wb=wb)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
Loading…
x
Reference in New Issue
Block a user