from local

This commit is contained in:
dohe0342 2023-01-09 20:23:52 +09:00
parent cc5c3eff53
commit ff62d9d54f
3 changed files with 6 additions and 1 deletions

View File

@ -982,6 +982,11 @@ def run(rank, world_size, args):
transducer_model.load_state_dict(pre_trained_model, strict=True)
model = get_interformer_model(transducer_model.encoder, params)
for n, p in model.named_parameters():
if 'pt_encoder' in n:
p.requires_grad = False
else:
print(n)
'''
for n, p in model.named_parameters():
if 'layer' not in n:
@ -1016,7 +1021,7 @@ def run(rank, world_size, args):
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
optimizer = Eve(model.parameters(), lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)