From dcab1aee4e250dfa5c0993c3f57738aa08541056 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 21 May 2022 17:47:30 +0800
Subject: [PATCH] make biases learn faster in a different way.

---
 egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 2 +-
 egs/librispeech/ASR/pruned_transducer_stateless4/train.py   | 6 +-----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 837e4297f..e490ca20d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -179,7 +179,7 @@ class ScaledLinear(nn.Linear):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                self.bias[:] *= initial_scale
+                self.bias[:] *= initial_scale * 4.0
 
     def get_weight(self):  # not needed any more but kept for back compatibility
         return self.weight
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index daccbee04..c5391043e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -871,11 +871,7 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-
-    params_to_pass = [ {'params': [ p for (name,p) in model.named_parameters() if 'bias' not in name] },
-                       {'params': [ p for (name,p) in model.named_parameters() if 'bias' in name ], 'lr': params.initial_lr*2.0 } ]
-
-    optimizer = Cain(params_to_pass, lr=params.initial_lr)
+    optimizer = Cain(model.parameters(), lr=params.initial_lr)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)