from local

This commit is contained in:
dohe0342 2023-06-09 15:42:02 +09:00
parent 4929de22dc
commit 8f207043c5
5 changed files with 94 additions and 143 deletions

View File

@@ -38,6 +38,24 @@ from convolution import ConvolutionModule

logger = logging.getLogger()
logger.setLevel(logging.INFO)

class LoRAHook:
    def __init__(self, module, embedding_dim, rank, lora_alpha):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.lora = LoRAModule(
            embedding_dim=embedding_dim,
            rank=rank,
            lora_alpha=lora_alpha,
        )

    def hook_fn(self, module, input, output):
        # Returning a value from a forward hook replaces the module's
        # output; this is more robust than mutating `output` in place.
        lora_out = self.lora(input[0])
        return output + lora_out

    def save_checkpoint(self, i, iter_, save_dir):
        # Unwrap DDP so the saved state dict keys stay stable.
        lora = self.lora.module if isinstance(self.lora, DDP) else self.lora
        torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")

class TransformerEncoderAdapter(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
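
LoRAModule is imported from data2vec_audio later in this commit but its definition is not part of the diff. A minimal sketch, assuming the standard LoRA formulation and matching the constructor arguments used above (the real implementation may differ):

import torch
import torch.nn as nn

class LoRAModule(nn.Module):
    """Hypothetical low-rank adapter: y = (lora_alpha / rank) * B(A(x))."""

    def __init__(self, embedding_dim: int, rank: int, lora_alpha: float):
        super().__init__()
        self.lora_A = nn.Linear(embedding_dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, embedding_dim, bias=False)
        self.scaling = lora_alpha / rank
        # Standard LoRA init: B starts at zero, so training begins from
        # the unmodified backbone output.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora_B(self.lora_A(x)) * self.scaling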

View File

@@ -101,6 +101,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
from data2vec_encoder import FairSeqData2VecEncoder
from data2vec_audio import LoRAModule, LoRAHook
from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
@@ -123,8 +124,8 @@ from icefall.utils import (
)
import wandb
import fairseq
#from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -138,26 +139,33 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    model.encoder.num_updates = int(batch_count)

def add_adapter_arguments(parser: argparse.ArgumentParser):
def add_pea_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--add-adapter",
        "--adapter",
        type=str2bool,
        default=False,
        help="add an adapter to the rep model's encoder",
    )
    parser.add_argument(
        "--adapter-lr",
        type=float,
        default=0.0001,
        help="adapter learning rate",
        "--bitfit",
        type=str2bool,
        default=False,
        help="bias-only (BitFit) training for PEA",
    )
    parser.add_argument(
        "--lora",
        type=str2bool,
        default=False,
        help="Low-Rank Adaptation (LoRA) training for PEA",
    )
    parser.add_argument(
        "--gender",
        type=str,
        default='male',
        help="select gender",
        "--pea-lr",
        type=float,
        default=0.0001,
        help="PEA learning rate",
    )
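
A quick check of how the new flags parse, as a sketch (str2bool maps "true"/"false" command-line strings to booleans):

parser = argparse.ArgumentParser()
add_pea_arguments(parser)
args = parser.parse_args(["--bitfit", "false", "--lora", "true", "--pea-lr", "5e-5"])
assert args.lora and not args.bitfit and args.pea_lr == 5e-5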
@@ -314,12 +322,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        """,
    )

    parser.add_argument(
        "--prompt",
        type=str2bool,
        default=False,
    )

def get_parser():
    parser = argparse.ArgumentParser(
@@ -528,10 +530,16 @@ def get_parser():
        default=True,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--pea",
        type=str2bool,
        default=True,
        help="Whether to train with parameter-efficient adaptation (PEA)",
    )

    add_model_arguments(parser)
    add_rep_arguments(parser)
    add_adapter_arguments(parser)
    add_pea_arguments(parser)

    return parser
@@ -588,7 +596,7 @@ def get_params() -> AttributeDict:
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 5,
            "log_interval": 20,
            "reset_interval": 200,
            "valid_interval": 3000,  # For the 100h subset, use 800
            # parameters for zipformer
@@ -663,7 +671,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder=encoder,
        decoder=decoder,
@@ -672,8 +680,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
        prompt=params.prompt,
        sid=params.spk_id,
    )
    return model
@@ -714,7 +720,7 @@ def load_checkpoint_if_available(
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    elif params.add_adapter:
    elif params.pea:
        filename = params.exp_dir / "../d2v-base-T.pt"
    else:
        return None
@@ -727,7 +733,7 @@ def load_checkpoint_if_available(
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
        strict=True if not params.add_adapter else False,
        # The seed checkpoint has no PEA parameters, so load non-strictly.
        strict=not params.pea,
    )
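
The non-strict load matters because, with --pea enabled, the seed checkpoint d2v-base-T.pt predates the PEA parameters, so a strict load would raise on missing keys. The underlying idea, as a minimal sketch (the "model" key follows icefall's checkpoint layout):

ckpt = torch.load(filename, map_location="cpu")
missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
logging.info(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")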
    keys = [
@@ -1001,6 +1007,7 @@ def train_one_epoch(
    world_size: int = 1,
    rank: int = 0,
    wb=None,
    lora_modules=None,
) -> None:
    """Train the model for one epoch.
@@ -1100,54 +1107,23 @@ def train_one_epoch(
        if params.print_diagnostics and batch_idx == 5:
            return

        '''
        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )
        '''

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                model_avg=model_avg,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            del params.cur_batch_idx
            '''
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )
            '''
            if rank == 0:
                for i, lora in enumerate(lora_modules):
                    lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
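            # Sketch (not part of this commit): the matching restore path at
            # decode time, mirroring the file naming in LoRAHook.save_checkpoint:
            #   for i, hook in enumerate(lora_modules):
            #       state = torch.load(
            #           f"{params.exp_dir}/lora_{params.batch_idx_train}_{i}.pt",
            #           map_location="cpu",
            #       )
            #       hook.lora.load_state_dict(state)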
        if batch_idx % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The
            # _growth_interval of the grad scaler is configurable, but we
            # can't configure it to have different behavior depending on the
            # current grad scale. Note that scaler.get_scale() is the public
            # equivalent of the private scaler._scale.item() used here.
            cur_grad_scale = scaler._scale.item()
            '''
            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            '''
            if cur_grad_scale < 0.01:
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
@@ -1156,16 +1132,7 @@ def train_one_epoch(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )
        #if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
        #    wb.log({"valid/loss": 10000})
        #    raise RuntimeError(
        #        f"divergence... exiting: loss={loss}"
        #    )

        if batch_idx % (params.log_interval * params.accum_grads) == 0:
            #for n, p in model.named_parameters():
            #    if 'adapter' in n:
            #        print(p)
            if params.multi_optim:
                cur_enc_lr = scheduler_enc.get_last_lr()[0]
                cur_dec_lr = scheduler_dec.get_last_lr()[0]
@@ -1223,36 +1190,6 @@ def train_one_epoch(
                wb.log({"train/pruned_loss": loss_info["pruned_loss"] * numel})
                wb.log({"train/ctc_loss": loss_info["ctc_loss"] * numel})

        '''
        logging.info("Computing validation loss")
        valid_info = compute_validation_loss(
            params=params,
            model=model,
            sp=sp,
            valid_dl=valid_dl,
            world_size=world_size,
        )
        model.train()
        logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )
        if tb_writer is not None:
            valid_info.write_summary(
                tb_writer, "train/valid_", params.batch_idx_train
            )
        if wb is not None and rank == 0:
            numel = 1 / (params.accum_grads * valid_info["utterances"])
            #wb.log({"valid/loss": valid_info["loss"]*numel})
            wb.log({"valid/loss": numel * (valid_info["simple_loss"]
                                           + valid_info["pruned_loss"]
                                           + valid_info["ctc_loss"])})
            wb.log({"valid/simple_loss": valid_info["simple_loss"] * numel})
            wb.log({"valid/pruned_loss": valid_info["pruned_loss"] * numel})
            wb.log({"valid/ctc_loss": valid_info["ctc_loss"] * numel})
        '''

    loss_value = tot_loss["loss"] / tot_loss["utterances"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
@@ -1449,17 +1386,6 @@ def run(rank, world_size, args, wb=None):
        valid_cuts += librispeech.dev_other_cuts()
    valid_dl = librispeech.valid_dataloaders(valid_cuts)

    '''
    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
        )
    '''

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
@@ -1506,7 +1432,7 @@ def run(rank, world_size, args, wb=None):
        cleanup_dist()

def run_adapter(rank, world_size, args, wb=None):
def run_pea(rank, world_size, args, wb=None):
    """
    Args:
      rank:
@@ -1557,8 +1483,7 @@ def run_adapter(rank, world_size, args, wb=None):
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        #model_avg = copy.deepcopy(model).to(torch.float64)
        model_avg = None
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
@@ -1570,33 +1495,40 @@ def run_adapter(rank, world_size, args, wb=None):
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    adapter_names = []
    adapter_param = []
    for n, p in model.named_parameters():
        if 'q_proj.bias' in n or 'fc1.bias' in n:
            adapter_names.append(n)
            adapter_param.append(p)
        else:
            p.requires_grad = False
        '''
        if 'adapters' in n:  # or 'joiner' in n or 'simple' in n or 'ctc' in n:
            adapter_names.append(n)
            adapter_param.append(p)
        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
            p.requires_grad = True
        else:
            p.requires_grad = False
        '''
    optimizer_adapter = ScaledAdam(
        adapter_param,
        lr=params.adapter_lr,
        clipping_scale=5.0,
        parameters_names=[adapter_names],
    )

    # Attach a LoRA hook to every linear projection inside each fairseq
    # multi-head attention block; the backbone modules stay untouched.
    lora_modules = []
    for modules in model.modules():
        if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
            for module in modules.modules():
                if isinstance(module, torch.nn.Linear):
                    lora_modules.append(LoRAHook(
                        module,
                        embedding_dim=args.encoder_dim,
                        rank=args.rank,
                        lora_alpha=args.lora_alpha,
                    ))
    scheduler_adapter = Eden(optimizer_adapter, 10000, 7)  # params.lr_batches, params.lr_epochs
    optimizer, scheduler = optimizer_adapter, scheduler_adapter

    if world_size > 1:
        logging.info("Using DDP for LoRA")
        for module in lora_modules:
            module.lora = module.lora.to(device)
            module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)

    pea_names = []
    pea_param = []
    for i, module in enumerate(lora_modules):
        for n, p in module.lora.named_parameters():
            # Prefix each name with the module index so names stay unique.
            new_n = str(i) + n
            pea_names.append(new_n)
            pea_param.append(p)

    optimizer_pea = ScaledAdam(
        pea_param,
        lr=params.pea_lr,
        clipping_scale=5.0,
        parameters_names=[pea_names],
    )
    scheduler_pea = Eden(optimizer_pea, 10000, 7)
    optimizer, scheduler = optimizer_pea, scheduler_pea
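
    # Sketch (not part of this commit): only pea_param reaches the optimizer,
    # so the backbone gets no updates even where requires_grad stays True:
    #   n_pea = sum(p.numel() for p in pea_param)
    #   n_total = sum(p.numel() for p in model.parameters())
    #   logging.info(f"training {n_pea} PEA parameters over a {n_total}-parameter backbone")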
    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.vox_cuts(option=params.spk_id)
@@ -1611,7 +1543,7 @@ def run_adapter(rank, world_size, args, wb=None):
    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
    valid_dl = librispeech.valid_dataloaders(valid_cuts)
@@ -1643,6 +1575,7 @@ def run_adapter(rank, world_size, args, wb=None):
            world_size=world_size,
            rank=rank,
            wb=wb,
            lora_modules=lora_modules,
        )

    if params.print_diagnostics:
@@ -1746,13 +1679,13 @@ def main():
    world_size = args.world_size
    assert world_size >= 1

    if world_size > 1:
        mp.spawn(run if not args.add_adapter else run_adapter,
        mp.spawn(run if not args.pea else run_pea,
                 args=(world_size, args, wb),
                 nprocs=world_size,
                 join=True
        )
    else:
        if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
        if args.pea: run_pea(rank=0, world_size=1, args=args, wb=wb)
        else: run(rank=0, world_size=1, args=args, wb=wb)

torch.set_num_threads(1)