From 3b450c26828c98385901cca4716e4e651f8c4e5f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 16 Sep 2022 14:10:42 +0800
Subject: [PATCH] Bug fix in train.py, fix optimizer name

---
 egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index ac9a6b22b..5de7102d3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -66,7 +66,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eden, PrAdam
+from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -926,10 +926,8 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])

-    optimizer = PrAdam(model.parameters(),
-                       lr=params.initial_lr,
-                       max_block_size=512,
-                       lr_update_period=(400, 5000))
+    optimizer = ScaledAdam(model.parameters(),
+                           lr=params.initial_lr)

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
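
Below is a minimal sketch of how the patched optimizer/scheduler construction is used, assuming the recipe's optim.py (which provides ScaledAdam and Eden) is importable. The model, Params class, and dummy loss are hypothetical stand-ins for the objects train.py actually builds, and the numeric values are illustrative rather than the recipe defaults.

    import torch
    from optim import Eden, ScaledAdam  # from the recipe's optim.py

    # Placeholder model standing in for the real Transducer built in train.py.
    model = torch.nn.Linear(80, 500)

    class Params:
        # Hypothetical values; train.py reads these from command-line args.
        initial_lr = 0.05
        lr_batches = 5000
        lr_epochs = 3.5

    params = Params()

    # After the patch: ScaledAdam takes only the parameters and learning rate;
    # the old PrAdam-specific arguments (max_block_size, lr_update_period)
    # are gone.
    optimizer = ScaledAdam(model.parameters(), lr=params.initial_lr)
    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    # One illustrative training step (plain PyTorch; the real loop lives in
    # train.py's train_one_epoch, where the transducer loss is computed).
    x = torch.randn(4, 80)
    loss = model(x).pow(2).mean()  # dummy loss, not the transducer loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step_batch(1)  # Eden is advanced per batch in icefall recipes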