From ea5cd69e3b29b85403a247e92bfe017896009871 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jun 2022 20:36:36 +0800 Subject: [PATCH] Possibly fix bug RE learning rate --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 14fb08033..da7cde00c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -980,6 +980,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float ): from lhotse.dataset import find_pessimistic_batches @@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step()