From efde3757c77c482a13324bd799ef65b4754d9d56 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 22 Oct 2022 14:30:18 +0800
Subject: [PATCH] Reset optimizer state when we change loss function definition.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 11 ++++--
 .../ASR/pruned_transducer_stateless7/train.py | 35 ++++++++++---------
 2 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 0e5808369..b67408cbd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -228,6 +228,13 @@ class ScaledAdam(BatchedOptimizer):
 
         return loss
 
+    @torch.no_grad()
+    def reset(self):
+        for d in self.state.values():
+            # d should be a dict; clear all elements from it.
+            d.clear()
+
+
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -899,8 +906,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        #if epoch == 100 and iter in [2,3]:
-        #    optim.reset_speedup()  # check it doesn't crash.
+        if epoch == 100 and iter == 1:
+            optim.reset()
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index ad1348e7f..6559543fc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -399,7 +399,8 @@ def get_params() -> AttributeDict:
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - warm_step: The warm_step for Noam optimizer.
+        - warm_step: The warmup period that dictates when we introduce the
+          pruned version of the loss.
     """
     params = AttributeDict(
         {
@@ -415,7 +416,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -603,7 +604,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -636,6 +636,9 @@ def compute_loss(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
@@ -650,13 +653,13 @@ def compute_loss(
             lm_scale=params.lm_scale,
         )
         # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
+        # for the same amount of time (warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0 if batch_idx_train < warm_step
+            else 0.1 if batch_idx_train < 2 * warm_step
+            else 1.0
         )
         loss = (
             params.simple_loss_scale * simple_loss
@@ -781,7 +784,6 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -791,15 +793,14 @@
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train == params.model_warm_step:
-            # we're about to start using the pruned loss, which brings new
-            # modules into play, so reset the frequencies of update, to
-            # avoid possible instability.
-            try:
-                optimizer.reset_speedup()
-                logging.info("Reset speedup on optimizer")
-            except:
-                pass
+        if params.batch_idx_train in [params.warm_step,
+                                      2 * params.warm_step]:
+            logging.info("Resetting optimizer state due to change in loss definition.")
+            # We're about to start using the pruned loss, or rescale it,
+            # so reset the optimizer state to avoid possible instability
+            # due to the squared-gradient stats becoming inaccurate
+            # (too small).
+            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
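
For reference, the schedule that the reworked compute_loss() applies to the pruned loss can be summarized as a small standalone sketch. This is illustration only, not part of the patch; the helper name pruned_loss_scale and the default of 3000 simply mirror the expression and the warm_step value set in get_params() above.

def pruned_loss_scale(batch_idx_train: int, warm_step: int = 3000) -> float:
    # 0.0 before warm_step (only the simple loss trains the model),
    # 0.1 until 2 * warm_step (pruned loss introduced but kept small),
    # 1.0 afterwards (pruned loss at full weight).
    return (
        0.0 if batch_idx_train < warm_step
        else 0.1 if batch_idx_train < 2 * warm_step
        else 1.0
    )

assert pruned_loss_scale(0) == 0.0
assert pruned_loss_scale(3000) == 0.1
assert pruned_loss_scale(6000) == 1.0

The two boundaries, warm_step and 2 * warm_step, are exactly the batch indices at which train_one_epoch() now calls optimizer.reset().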
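
The new ScaledAdam.reset() simply empties every per-parameter state dict, so the optimizer re-estimates its statistics from scratch once the effective loss changes. Below is a minimal sketch of the same state-clearing pattern, shown on torch.optim.Adam as a stand-in for ScaledAdam (stock Adam also keeps its per-parameter state in optimizer.state and lazily re-initializes any parameter whose state dict is empty on the next step()); the free function reset() here is only an illustration of what the patched method does.

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take a few steps so the moment estimates (exp_avg, exp_avg_sq) get populated.
for _ in range(3):
    optim.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    optim.step()

@torch.no_grad()
def reset(optimizer: torch.optim.Optimizer) -> None:
    # Clear every per-parameter state dict; the optimizer re-initializes
    # its statistics on the next step().
    for d in optimizer.state.values():
        d.clear()

print(all(len(d) > 0 for d in optim.state.values()))   # True: state populated
reset(optim)
print(all(len(d) == 0 for d in optim.state.values()))  # True: state cleared

Clearing the state in place, rather than constructing a new optimizer, keeps the parameter groups and the attached learning-rate schedule intact.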