from local

parent 4929de22dc
commit 8f207043c5
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -38,6 +38,24 @@ from convolution import ConvolutionModule

 logger = logging.getLogger().setLevel(logging.INFO)

+class LoRAHook():
+    def __init__(self, module, embedding_dim, rank, lora_alpha):
+        self.hook = module.register_forward_hook(self.hook_fn)
+        self.lora = LoRAModule(
+            embedding_dim=embedding_dim,
+            rank=rank,
+            lora_alpha=lora_alpha,
+        )
+
+    def hook_fn(self, module, input, output):
+        lora_out = self.lora(input[0])
+        output += lora_out
+
+    def save_checkpoint(self, i, iter_, save_dir):
+        if isinstance(self.lora, DDP):
+            lora = self.lora.module
+        torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")
+
 class TransformerEncoderAdapter(TransformerEncoder):
     def __init__(self, args: Wav2Vec2Config):
         super().__init__(args)
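Note: LoRAModule is referenced by LoRAHook (and imported in a later hunk) but its definition is not part of these hunks. For orientation only, a minimal low-rank adapter of the kind LoRAHook appears to expect might look like the sketch below; the class name, initialization, and scaling are assumptions, not the repository's actual implementation.

# Illustrative only -- not the LoRAModule shipped in this commit.
import torch
import torch.nn as nn


class LoRAModuleSketch(nn.Module):
    """Low-rank residual: x -> (x @ A @ B) * (lora_alpha / rank)."""

    def __init__(self, embedding_dim: int, rank: int, lora_alpha: float):
        super().__init__()
        self.scaling = lora_alpha / rank
        # A starts small and B starts at zero, so the residual is zero at init.
        self.lora_A = nn.Parameter(torch.randn(embedding_dim, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x @ self.lora_A @ self.lora_B) * self.scaling

Because the hooked projections in the attention blocks map embedding_dim to embedding_dim, the residual produced from input[0] has the same shape as the module output that hook_fn adds it to.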
@@ -101,6 +101,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
 from data2vec_encoder import FairSeqData2VecEncoder
+from data2vec_audio import LoRAModule, LoRAHook

 from icefall import diagnostics
 from icefall.checkpoint import remove_checkpoints
@@ -123,8 +124,8 @@ from icefall.utils import (
 )

 import wandb
+import fairseq

-#from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


@@ -138,26 +139,33 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     model.encoder.num_updates = int(batch_count)


-def add_adapter_arguments(parser: argparse.ArgumentParser):
+def add_pea_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--add-adapter",
+        "--adapter",
         type=str2bool,
         default=False,
         help="add adapter to rep model's encoder"
     )

     parser.add_argument(
-        "--adapter-lr",
-        type=float,
-        default=0.0001,
-        help="adapter learning rate"
+        "--bitfit",
+        type=str2bool,
+        default=False,
+        help="bias only training for PEA"
     )

     parser.add_argument(
-        "--gender",
-        type=str,
-        default='male',
-        help="select gender"
+        "--lora",
+        type=str2bool,
+        default=False,
+        help="Low Rank Adaptation training for PEA"
+    )
+
+    parser.add_argument(
+        "--pea-lr",
+        type=float,
+        default=0.0001,
+        help="PEA learning rate"
     )


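The --bitfit flag introduced above is not consumed by any hunk in this commit. As a reference point only, bias-only (BitFit-style) selection in the spirit of the 'q_proj.bias' / 'fc1.bias' loop removed later in this diff could be sketched as follows; the function name and the ".bias" criterion are assumptions, not the repository's code.

# Illustrative BitFit-style parameter selection; not code from this commit.
def select_bias_parameters(model):
    """Freeze everything except bias terms and return the trainable ones."""
    names, params = [], []
    for n, p in model.named_parameters():
        if n.endswith(".bias"):
            names.append(n)
            params.append(p)
        else:
            p.requires_grad = False
    return names, params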
@@ -314,12 +322,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

-    parser.add_argument(
-        "--prompt",
-        type=str2bool,
-        default=False,
-    )
-

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -529,9 +531,15 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--pea",
+        type=str2bool,
+        default=True,
+        help="Whether to train parameter efficient adaptation",
+    )
     add_model_arguments(parser)
     add_rep_arguments(parser)
-    add_adapter_arguments(parser)
+    add_pea_arguments(parser)

     return parser

@@ -588,7 +596,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 5,
+            "log_interval": 20,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
@@ -672,8 +680,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
-        prompt=params.prompt,
-        sid=params.spk_id,
     )
     return model

@@ -714,7 +720,7 @@ def load_checkpoint_if_available(
         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-    elif params.add_adapter:
+    elif params.pea:
         filename = params.exp_dir / f"../d2v-base-T.pt"
     else:
         return None
@@ -727,7 +733,7 @@ def load_checkpoint_if_available(
         model_avg=model_avg,
         optimizer=optimizer,
         scheduler=scheduler,
-        strict=True if not params.add_adapter else False,
+        strict=True if not params.pea else False,
     )

     keys = [
@@ -1001,6 +1007,7 @@ def train_one_epoch(
     world_size: int = 1,
     rank: int = 0,
     wb = None,
+    lora_modules = None,
 ) -> None:
     """Train the model for one epoch.

@@ -1100,54 +1107,23 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
            return

-        '''
-        if (
-            rank == 0
-            and params.batch_idx_train > 0
-            and params.batch_idx_train % params.average_period == 0
-        ):
-            update_averaged_model(
-                params=params,
-                model_cur=model,
-                model_avg=model_avg,
-            )
-        '''
-
        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            params.cur_batch_idx = batch_idx
-            save_checkpoint_with_global_batch_idx(
-                out_dir=params.exp_dir,
-                global_batch_idx=params.batch_idx_train,
-                model=model,
-                model_avg=model_avg,
-                params=params,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
            del params.cur_batch_idx
-            '''
-            remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
-            )
-            '''
+            if rank == 0:
+                for i, lora in enumerate(lora_modules):
+                    lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)

        if batch_idx % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
-            '''
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
-                scaler.update(cur_grad_scale * 2.0)
-            '''
            if cur_grad_scale < 0.01:
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
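The change above replaces the full training-state checkpoint with one state dict per hooked module, written by LoRAHook.save_checkpoint as lora_{iter_}_{i}.pt. A hedged sketch of restoring those files later, assuming lora_modules is rebuilt in the same order as during training and that the LoRA modules are no longer DDP-wrapped at that point:

# Illustrative restore loop; ordering and path assumptions are noted above.
import torch

def load_lora_checkpoints(lora_modules, save_dir, iter_):
    """Load the per-module LoRA weights written by LoRAHook.save_checkpoint."""
    for i, hook in enumerate(lora_modules):
        state = torch.load(f"{save_dir}/lora_{iter_}_{i}.pt", map_location="cpu")
        hook.lora.load_state_dict(state)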
@@ -1156,16 +1132,7 @@ def train_one_epoch(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

-        #if params.batch_idx_train > 4000 and loss > 300 and params.wandb:
-        #    wb.log({"valid/loss": 10000})
-        #    raise RuntimeError(
-        #        f"divergence... exiting: loss={loss}"
-        #    )
-
        if batch_idx % (params.log_interval*params.accum_grads) == 0:
-            #for n, p in model.named_parameters():
-            #    if 'adapter' in n:
-            #        print(p)
            if params.multi_optim:
                cur_enc_lr = scheduler_enc.get_last_lr()[0]
                cur_dec_lr = scheduler_dec.get_last_lr()[0]
@@ -1223,36 +1190,6 @@ def train_one_epoch(
                wb.log({"train/pruned_loss": loss_info["pruned_loss"]*numel})
                wb.log({"train/ctc_loss": loss_info["ctc_loss"]*numel})

-            '''
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
-
-            if wb is not None and rank == 0:
-                numel = 1 / (params.accum_grads * valid_info["utterances"])
-                #wb.log({"valid/loss": valid_info["loss"]*numel})
-                wb.log({"valid/loss": numel*(valid_info["simple_loss"]
-                                             +valid_info["pruned_loss"]
-                                             +valid_info["ctc_loss"]
-                                             )})
-                wb.log({"valid/simple_loss": valid_info["simple_loss"]*numel})
-                wb.log({"valid/pruned_loss": valid_info["pruned_loss"]*numel})
-                wb.log({"valid/ctc_loss": valid_info["ctc_loss"]*numel})
-            '''
        loss_value = tot_loss["loss"] / tot_loss["utterances"]
        params.train_loss = loss_value
        if params.train_loss < params.best_train_loss:
@@ -1449,17 +1386,6 @@ def run(rank, world_size, args, wb=None):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    '''
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
-    '''
-
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1506,7 +1432,7 @@ def run(rank, world_size, args, wb=None):
         cleanup_dist()


-def run_adapter(rank, world_size, args, wb=None):
+def run_pea(rank, world_size, args, wb=None):
     """
     Args:
       rank:
@@ -1557,8 +1483,7 @@ def run_adapter(rank, world_size, args, wb=None):
     model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
-        #model_avg = copy.deepcopy(model).to(torch.float64)
-        model_avg = None
+        model_avg = copy.deepcopy(model).to(torch.float64)

     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
@@ -1570,33 +1495,40 @@ def run_adapter(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    adapter_names = []
-    adapter_param = []
-    for n, p in model.named_parameters():
-        if 'q_proj.bias' in n or 'fc1.bias' in n:
-            adapter_names.append(n)
-            adapter_param.append(p)
-        else:
-            p.requires_grad = False
-
-    '''
-        if 'adapters' in n:# or 'joiner' in n or 'simple' in n or 'ctc' in n:
-            adapter_names.append(n)
-            adapter_param.append(p)
-        elif 'joiner' in n or 'simple' in n or 'ctc' in n:
-            p.requires_grad = True
-        else:
-            p.requires_grad = False
-    '''
-    optimizer_adapter = ScaledAdam(
-        adapter_param,
-        lr=params.adapter_lr,
+    lora_modules = []
+    for modules in model.modules():
+        if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
+            for module in modules.modules():
+                if isinstance(module, torch.nn.Linear):
+                    lora_modules.append(LoRAHook(
+                        module,
+                        embedding_dim=args.encoder_dim,
+                        rank=args.rank,
+                        lora_alpha=args.lora_alpha,
+                    ))
+
+    if world_size > 1:
+        logging.info("Using DDP for LoRA")
+        for module in lora_modules:
+            module.lora = module.lora.to(device)
+            module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
+
+    pea_names = []
+    pea_param = []
+    for i, module in enumerate(lora_modules):
+        for n, p in module.lora.named_parameters():
+            new_n = str(i) + n
+            pea_names.append(new_n)
+            pea_param.append(p)
+
+    optimizer_pea = ScaledAdam(
+        pea_param,
+        lr=params.pea_lr,
         clipping_scale=5.0,
-        parameters_names=[adapter_names],
+        parameters_names=[pea_names],
     )
-
-    scheduler_adapter = Eden(optimizer_adapter, 10000, 7) #params.lr_batche, params.lr_epochs)
-    optimizer, scheduler = optimizer_adapter, scheduler_adapter
+    scheduler_pea = Eden(optimizer_pea, 10000, 7)
+    optimizer, scheduler = optimizer_pea, scheduler_pea

     librispeech = LibriSpeechAsrDataModule(args)
     train_cuts = librispeech.vox_cuts(option=params.spk_id)
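For context on why the hook-based attachment above changes the encoder's behaviour: register_forward_hook fires after the wrapped module's forward, and the in-place `output += lora_out` in LoRAHook.hook_fn adds the low-rank correction to the tensor the module returns. A standalone sketch of the same mechanism on a plain nn.Linear, with a full-rank stand-in for the low-rank module (everything here is illustrative, not code from the commit):

# Minimal demonstration of adding a residual through a forward hook.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
delta = nn.Linear(8, 8, bias=False)  # stands in for the low-rank adapter

def hook_fn(module, inputs, output):
    # In-place update: the caller of layer(x) sees the modified output tensor.
    output += delta(inputs[0])

handle = layer.register_forward_hook(hook_fn)

x = torch.randn(2, 8)
with torch.no_grad():
    y_hooked = layer(x)
handle.remove()
with torch.no_grad():
    y_plain = layer(x)
print(torch.allclose(y_hooked, y_plain))  # False: the hook altered the output

Because only the hooked-in modules' parameters are handed to optimizer_pea above, the frozen base model is left untouched while the adapters train.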
@@ -1643,6 +1575,7 @@ def run_adapter(rank, world_size, args, wb=None):
         world_size=world_size,
         rank=rank,
         wb=wb,
+        lora_modules=lora_modules,
     )

     if params.print_diagnostics:
@@ -1746,13 +1679,13 @@ def main():
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
-        mp.spawn(run if not args.add_adapter else run_adapter,
+        mp.spawn(run if not args.pea else run_pea,
                 args=(world_size, args, wb),
                 nprocs=world_size,
                 join=True
                 )
     else:
-        if args.add_adapter: run_adapter(rank=0, world_size=1, args=args, wb=wb)
+        if args.pea: run_pea(rank=0, world_size=1, args=args, wb=wb)
         else: run(rank=0, world_size=1, args=args, wb=wb)

     torch.set_num_threads(1)