from local

dohe0342 2023-06-09 15:55:55 +09:00
parent 7a4cc6f9b6
commit f11cfd4ebe
6 changed files with 50 additions and 18 deletions

@@ -366,3 +366,8 @@ class TedLiumAsrDataModule:
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
         return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def user_test_cuts(self, user: str) -> CutSet:
+        logging.info(f"About to get test cuts for user {user}")
+        return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
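As a reading aid (not part of this commit): a minimal sketch of how the new per-user loader might be driven from a decode script. `TedLiumAsrDataModule` and `user_test_cuts` come from this diff; the user IDs, `test_dataloaders`, and `decode_dataset` are assumed from the usual icefall recipe layout.

# Hypothetical usage sketch: decode each user's test set, assuming
# manifests named tedlium_cuts_test_<user>.jsonl.gz exist under
# --manifest-dir.
tedlium = TedLiumAsrDataModule(args)
for user in ["spk01", "spk02"]:            # placeholder user IDs
    cuts = tedlium.user_test_cuts(user)    # lazy CutSet, cached per user
    dl = tedlium.test_dataloaders(cuts)    # standard icefall data-module helper
    results = decode_dataset(dl, params, model)  # hypothetical decode entry point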

@@ -160,6 +160,20 @@ def add_pea_arguments(parser: argparse.ArgumentParser):
         default=False,
         help="Low Rank Adaptation training for PEA"
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=2,
+        help="Rank for Low Rank Adaptation. [default: 2]",
+    )
+    parser.add_argument(
+        "--lora-alpha",
+        type=float,
+        default=10000.0,
+        help="Alpha for Low Rank Adaptation; the LoRA output is multiplied by alpha. [default: 10000.0]",
+    )
     parser.add_argument(
         "--pea-lr",
@@ -1114,9 +1128,9 @@ def train_one_epoch(
             params.cur_batch_idx = batch_idx
             del params.cur_batch_idx
-            if rank == 0:
-                for i, lora in enumerate(lora_modules):
-                    lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
+            # if rank == 0:
+            #     for i, lora in enumerate(lora_modules):
+            #         lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
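The truncated comment above refers to PyTorch AMP's `GradScaler`, whose growth interval controls how many overflow-free steps must pass before the loss scale is raised. A small reference sketch with the standard `torch.cuda.amp` API (independent of this repo; the 100-batch cadence comes from the loop above):

import logging
from torch.cuda.amp import GradScaler

scaler = GradScaler(growth_interval=2000)  # PyTorch's default interval
# A scale stuck below 1.0 usually signals persistent fp16 overflows.
if scaler.get_scale() < 1.0:
    logging.warning("grad scale is small; check for inf/NaN gradients")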
@@ -1495,26 +1509,39 @@ def run_pea(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    lora_modules = []
-    for modules in model.modules():
-        if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
-            for module in modules.modules():
-                if isinstance(module, torch.nn.Linear):
-                    lora_modules.append(LoRAHook(
+    pea_names = []
+    pea_param = []
+    lora_modules = None
+    if args.bitfit:
+        # BitFit: train only the encoder bias terms; freeze everything else.
+        for n, p in model.named_parameters():
+            if "encoder.encoder" in n and "bias" in n and "feature_extractor" not in n:
+                pea_names.append(n)
+                pea_param.append(p)
+            else:
+                p.requires_grad = False
+    elif args.lora:
+        # LoRA: hook every linear projection inside the multi-head attention blocks.
+        lora_modules = []
+        for modules in model.modules():
+            if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
+                for module in modules.modules():
+                    if isinstance(module, torch.nn.Linear):
+                        lora_modules.append(
+                            LoRAHook(
                                 module,
                                 embedding_dim=args.encoder_dim,
                                 rank=args.rank,
                                 lora_alpha=args.lora_alpha,
-                    ))
+                            )
+                        )
-    if world_size > 1:
-        logging.info("Using DDP for LoRA")
-        for module in lora_modules:
-            module.lora = module.lora.to(device)
-            module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
+        if world_size > 1:
+            logging.info("Using DDP for LoRA")
+            for module in lora_modules:
+                module.lora = module.lora.to(device)
+                module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
-    pea_names = []
-    pea_param = []
     for i, module in enumerate(lora_modules):
         for n, p in module.lora.named_parameters():
             new_n = str(i) + n
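`LoRAHook` is referenced throughout but not defined in this diff. As a reading aid, here is a minimal sketch of a forward-hook LoRA wrapper consistent with the call signature above and with `run_pea` wrapping `module.lora` in DDP; the hook mechanics, initialization, and checkpoint filename are assumptions, not the repo's actual implementation.

import torch
import torch.nn as nn

class LoRAHook:
    """Hedged sketch: attach a trainable low-rank adapter to a frozen
    nn.Linear via a forward hook; only self.lora holds trainable state."""

    def __init__(self, module: nn.Linear, embedding_dim: int,
                 rank: int = 2, lora_alpha: float = 10000.0):
        self.alpha = lora_alpha
        # Kept as a plain nn.Module so run_pea() can .to(device) it
        # and wrap it in DDP, as the code above does.
        self.lora = nn.Sequential(
            nn.Linear(embedding_dim, rank, bias=False),  # down-projection A
            nn.Linear(rank, embedding_dim, bias=False),  # up-projection B
        )
        nn.init.zeros_(self.lora[1].weight)  # adapter starts as a no-op
        module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Returning a value from a forward hook replaces the layer output.
        return output + self.alpha * self.lora(inputs[0])

    def save_checkpoint(self, idx, batch_idx_train, exp_dir):
        # Matches the call commented out in train_one_epoch above;
        # the filename pattern is a guess.
        torch.save(self.lora.state_dict(),
                   f"{exp_dir}/lora_{idx}-{batch_idx_train}.pt")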