Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit f11cfd4ebe ("from local"), parent 7a4cc6f9b6
New binary file: egs/tedlium3/ASR/.prepare.sh.swp (contents not shown)
@@ -366,3 +366,8 @@ class TedLiumAsrDataModule:
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
         return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def user_test_cuts(self, user) -> CutSet:
+        logging.info("About to get test cuts")
+        return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
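For context, a method like user_test_cuts would typically be consumed from a decode script. The sketch below is a hypothetical illustration, not part of this commit: it assumes the per-user manifests tedlium_cuts_test_<user>.jsonl.gz have already been prepared and that the data module exposes the usual icefall test_dataloaders helper.

# Hypothetical usage sketch (not from this commit).
from asr_datamodule import TedLiumAsrDataModule  # recipe-local module

def decode_per_user(args, users):
    tedlium = TedLiumAsrDataModule(args)
    for user in users:
        cuts = tedlium.user_test_cuts(user)   # memoized per user by @lru_cache
        dl = tedlium.test_dataloaders(cuts)   # helper usually provided by icefall data modules
        # ... run decoding over `dl` and report per-user WER ...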
@@ -160,6 +160,20 @@ def add_pea_arguments(parser: argparse.ArgumentParser):
         default=False,
         help="Low Rank Adaptation training for PEA"
     )
 
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=2,
+        help="Rank for Low Rank Adaptation. [default = 2]"
+    )
+
+    parser.add_argument(
+        "--lora-alpha",
+        type=float,
+        default=10000.,
+        help="alpha for Low Rank Adaptation. alpha will be multiplied to lora output. [default = 10000.]"
+    )
+
     parser.add_argument(
         "--pea-lr",
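To make the behaviour of the two new options concrete, here is a small, self-contained parse example. It is a hypothetical sketch: in the real script the flags are registered by add_pea_arguments alongside the pre-existing PEA/LoRA options rather than declared standalone.

# Hypothetical sketch (not from this commit): the new flags in isolation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=2)
parser.add_argument("--lora-alpha", type=float, default=10000.)

args = parser.parse_args(["--rank", "4", "--lora-alpha", "16"])
print(args.rank, args.lora_alpha)  # -> 4 16.0

Note that argparse maps --lora-alpha to the attribute args.lora_alpha, which is how the training code reads it.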
@@ -1114,9 +1128,9 @@ def train_one_epoch(
             params.cur_batch_idx = batch_idx
             del params.cur_batch_idx
 
-            if rank == 0:
-                for i, lora in enumerate(lora_modules):
-                    lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
+            #if rank == 0:
+            #    for i, lora in enumerate(lora_modules):
+            #        lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
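This hunk disables the per-step saving of the LoRA adapters. If the adapters were instead snapshotted elsewhere (say, once per epoch), a helper along the lines below would suffice. This is a hypothetical sketch; the real LoRAHook.save_checkpoint is defined elsewhere in the patch and may store things differently.

# Hypothetical sketch (not from this commit).
from pathlib import Path
import torch

def save_lora_adapters(lora_modules, batch_idx_train, exp_dir):
    for i, hook in enumerate(lora_modules):
        out = Path(exp_dir) / f"lora-{i}-iter-{batch_idx_train}.pt"
        # Persist only the small adapter weights, not the frozen base model.
        torch.save(hook.lora.state_dict(), out)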
@@ -1495,26 +1509,39 @@ def run_pea(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    lora_modules = []
-    for modules in model.modules():
-        if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
-            for module in modules.modules():
-                if isinstance(module, torch.nn.Linear):
-                    lora_modules.append(LoRAHook(
-                        module,
-                        embedding_dim=args.encoder_dim,
-                        rank=args.rank,
-                        lora_alpha=args.lora_alpha,
-                    ))
-
-    if world_size > 1:
-        logging.info("Using DDP for LoRA")
-        for module in lora_modules:
-            module.lora = module.lora.to(device)
-            module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
-
-    pea_names = []
-    pea_param = []
+    pea_names = []
+    pea_param = []
+    lora_modules = None
+
+    if args.bitfit:
+        for n, p in model.named_parameters():
+            if 'encoder.encoder' in n and 'bias' in n and 'feature_extractor' not in n:
+                pea_names.append(n)
+                pea_param.append(p)
+            else:
+                p.requires_grad = False
+
+    elif args.lora:
+        lora_modules = []
+        for modules in model.modules():
+            if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
+                for module in modules.modules():
+                    if isinstance(module, torch.nn.Linear):
+                        lora_modules.append(
+                            LoRAHook(
+                                module,
+                                embedding_dim=args.encoder_dim,
+                                rank=args.rank,
+                                lora_alpha=args.lora_alpha,
+                            )
+                        )
+
+        if world_size > 1:
+            logging.info("Using DDP for LoRA")
+            for module in lora_modules:
+                module.lora = module.lora.to(device)
+                module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
+
     for i, module in enumerate(lora_modules):
         for n, p in module.lora.named_parameters():
             new_n = str(i) + n
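LoRAHook itself is defined elsewhere in this patch. To show the general pattern the hunk relies on, attaching a trainable low-rank correction to each frozen nn.Linear inside the attention blocks via a forward hook and scaling it by lora_alpha, here is a hypothetical, self-contained sketch; the class and attribute names are assumptions, not the patch's actual implementation.

# Hypothetical sketch of a LoRA forward hook (not from this commit).
import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    """Low-rank residual update: x -> alpha * up(down(x)), with `up` initialised to zero."""
    def __init__(self, embedding_dim, rank, lora_alpha):
        super().__init__()
        self.down = nn.Linear(embedding_dim, rank, bias=False)
        self.up = nn.Linear(rank, embedding_dim, bias=False)
        nn.init.zeros_(self.up.weight)  # adaptation starts as a no-op
        self.scale = lora_alpha

    def forward(self, x):
        return self.scale * self.up(self.down(x))

class LoRAHookSketch:
    """Add a trainable low-rank correction to the output of a (square) nn.Linear."""
    def __init__(self, module, embedding_dim, rank, lora_alpha):
        self.lora = LoRAAdapter(embedding_dim, rank, lora_alpha)
        module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return output + self.lora(inputs[0])

# Minimal usage: hook one projection and train only the adapter parameters.
proj = nn.Linear(16, 16)
hook = LoRAHookSketch(proj, embedding_dim=16, rank=2, lora_alpha=1.0)
y = proj(torch.randn(4, 16))  # now includes the (initially zero) LoRA term

Only the adapter parameters, which the hunk collects into pea_param via module.lora.named_parameters(), would then be handed to the optimizer, which keeps the adaptation parameter-efficient.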