Slow down learning of simple_am_proj.weight by 0.5

This commit is contained in:
Daniel Povey 2022-05-21 17:06:24 +08:00
parent 11eac9089e
commit b7adb6d738

View File

@@ -871,7 +871,11 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
optimizer = Cain(model.parameters(), lr=params.initial_lr)
params_to_pass = [ {'params': [ p for (name,p) in model.named_parameters() if name != 'simple_am_proj.weight' ] },
{'params': [ p for (name,p) in model.named_parameters() if name == 'simple_am_proj.weight' ], 'lr': params.initial_lr*0.25 } ]
optimizer = Cain(params_to_pass, lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)