mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

commit f11cfd4ebe ("from local")
parent 7a4cc6f9b6
BIN  egs/tedlium3/ASR/.prepare.sh.swp  (new file)
Binary file not shown.
@@ -366,3 +366,8 @@ class TedLiumAsrDataModule:
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
         return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def user_test_cuts(self, user) -> CutSet:
+        logging.info("About to get test cuts")
+        return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
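Usage note (not part of the commit): user_test_cuts mirrors test_cuts but keys the manifest by user, so a decoding script could loop over users and fetch one cached CutSet per speaker. A minimal sketch, assuming an externally supplied users list and the data module's existing test_dataloaders helper; evaluate_users itself is illustrative and does not exist in the diff:

```python
# Sketch only: evaluate each per-user test manifest added above.
import logging

from asr_datamodule import TedLiumAsrDataModule  # existing recipe module


def evaluate_users(args, users):
    tedlium = TedLiumAsrDataModule(args)
    for user in users:  # e.g. ["user1", "user2"]; the real user IDs are not in the diff
        cuts = tedlium.user_test_cuts(user)    # cached per user by @lru_cache
        test_dl = tedlium.test_dataloaders(cuts)
        logging.info(f"Decoding test set for user {user}")
        # decoding / scoring over test_dl would go here
```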
@@ -160,6 +160,20 @@ def add_pea_arguments(parser: argparse.ArgumentParser):
         default=False,
         help="Low Rank Adaptation training for PEA"
     )
 
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=2,
+        help="Rank for Low Rank Adaptation. [default = 2]"
+    )
+
+    parser.add_argument(
+        "--lora-alpha",
+        type=float,
+        default=10000.,
+        help="alpha for Low Rank Adaptation. alpha will be multiplied to lora output. [default = 10000.]"
+    )
+
     parser.add_argument(
         "--pea-lr",
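For orientation, --rank sets the bottleneck dimension of the low-rank update and, per the help text, --lora-alpha is a scalar multiplied onto the adapter output (many LoRA implementations scale by alpha/rank instead). A minimal, self-contained sketch of that parametrization follows; it is not the LoRAHook class this script actually uses:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W x + alpha * B(A(x))."""

    def __init__(self, base: nn.Linear, rank: int = 2, lora_alpha: float = 10000.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # the wrapped layer stays frozen
            p.requires_grad = False
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.alpha = lora_alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.alpha * self.lora_b(self.lora_a(x))
```

With the second projection initialized to zero, the wrapped layer initially behaves exactly like the frozen original, and only the two small projections are trained; the large default alpha of 10000 means even tiny adapter weights shift the output noticeably, which is worth keeping in mind when tuning --pea-lr.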
@@ -1114,9 +1128,9 @@ def train_one_epoch(
             params.cur_batch_idx = batch_idx
             del params.cur_batch_idx
 
-            if rank == 0:
-                for i, lora in enumerate(lora_modules):
-                    lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
+            #if rank == 0:
+            #    for i, lora in enumerate(lora_modules):
+            #        lora.save_checkpoint(i, params.batch_idx_train, params.exp_dir)
 
         if batch_idx % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
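The block commented out above saved one small checkpoint per adapter from rank 0. LoRAHook.save_checkpoint is not shown in this diff; the sketch below only illustrates the general idea of dumping adapter-only state dicts, with a hypothetical file-naming scheme:

```python
import os

import torch


def save_adapter_checkpoints(lora_modules, batch_idx_train, exp_dir):
    """Hypothetical stand-in for the per-module lora.save_checkpoint(...) calls.

    Saves only the small LoRA parameters, one file per wrapped layer, so the
    frozen base model never has to be re-serialized.
    """
    os.makedirs(exp_dir, exist_ok=True)
    for i, hook in enumerate(lora_modules):
        path = os.path.join(exp_dir, f"lora-{i}-batch-{batch_idx_train}.pt")
        # If hook.lora has been wrapped in DDP, hook.lora.module.state_dict()
        # would drop the "module." key prefix.
        torch.save(hook.lora.state_dict(), path)
```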
@@ -1495,26 +1509,39 @@ def run_pea(rank, world_size, args, wb=None):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    lora_modules = []
-    for modules in model.modules():
-        if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
-            for module in modules.modules():
-                if isinstance(module, torch.nn.Linear):
-                    lora_modules.append(LoRAHook(
+    pea_names = []
+    pea_param = []
+    lora_modules = None
+
+    if args.bitfit:
+        for n, p in model.named_parameters():
+            if 'encoder.encoder' in n and 'bias' in n and 'feature_extractor' not in n:
+                pea_names.append(n)
+                pea_param.append(p)
+            else:
+                p.requires_grad = False
+
+    elif args.lora:
+        lora_modules = []
+        for modules in model.modules():
+            if isinstance(modules, fairseq.modules.multihead_attention.MultiheadAttention):
+                for module in modules.modules():
+                    if isinstance(module, torch.nn.Linear):
+                        lora_modules.append(
+                            LoRAHook(
                                 module,
                                 embedding_dim=args.encoder_dim,
                                 rank=args.rank,
                                 lora_alpha=args.lora_alpha,
-                    ))
+                            )
+                        )
 
-    if world_size > 1:
-        logging.info("Using DDP for LoRA")
-        for module in lora_modules:
-            module.lora = module.lora.to(device)
-            module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
+
+        if world_size > 1:
+            logging.info("Using DDP for LoRA")
+            for module in lora_modules:
+                module.lora = module.lora.to(device)
+                module.lora = DDP(module.lora, device_ids=[rank], find_unused_parameters=False)
 
-    pea_names = []
-    pea_param = []
-    for i, module in enumerate(lora_modules):
-        for n, p in module.lora.named_parameters():
-            new_n = str(i) + n
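Two parameter-efficient paths are selected here: BitFit, which marks only encoder bias terms as trainable, and LoRA, which wraps every nn.Linear inside fairseq's MultiheadAttention with a LoRAHook. The hook class itself is not part of this diff; below is a hedged sketch of how such a hook-based adapter could look, reusing only the constructor signature and the .lora attribute visible at the call site. Everything else is an assumption, not the real implementation:

```python
import torch
import torch.nn as nn


class LoRAHook:
    """Hypothetical reconstruction, not the LoRAHook used above.

    Attaches a forward hook to a frozen nn.Linear and adds
    lora_alpha * B(A(x)) to its output, so the wrapped module's
    own code never has to change.
    """

    def __init__(self, module: nn.Linear, embedding_dim: int, rank: int, lora_alpha: float):
        self.alpha = lora_alpha
        # self.lora matches the attribute that the training code above moves
        # to the device and wraps in DDP.
        self.lora = nn.Sequential(
            nn.Linear(embedding_dim, rank, bias=False),
            nn.Linear(rank, module.out_features, bias=False),
        )
        nn.init.zeros_(self.lora[1].weight)  # start from the unadapted model
        for p in module.parameters():        # base layer stays frozen
            p.requires_grad = False
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return output + self.alpha * self.lora(inputs[0])
```

Under this reading, the adapters live outside the model's module tree, which would explain why each module.lora is wrapped in its own DDP above: the outer DDP(model, ...) would not synchronize their gradients.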