diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index af4bb577e..daccbee04 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -872,8 +872,8 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) - params_to_pass = [ {'params': [ p for (name,p) in model.named_parameters() if name != 'simple_am_proj.weight' ] }, - {'params': [ p for (name,p) in model.named_parameters() if name == 'simple_am_proj.weight' ], 'lr': params.initial_lr*0.25 } ] + params_to_pass = [ {'params': [ p for (name,p) in model.named_parameters() if 'bias' not in name] }, + {'params': [ p for (name,p) in model.named_parameters() if 'bias' in name ], 'lr': params.initial_lr*2.0 } ] optimizer = Cain(params_to_pass, lr=params.initial_lr)