diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index b67408cbd..0e5808369 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -228,13 +228,6 @@ class ScaledAdam(BatchedOptimizer):
 
         return loss
 
-    @torch.no_grad()
-    def reset(self):
-        for d in self.state.values():
-            # d should be a dict. clear all elements from it.
-            d.clear()
-
-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -906,8 +899,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        if epoch == 100 and iter == 1:
-            optim.reset()
+        #if epoch == 100 and iter in [2,3]:
+        #    optim.reset_speedup()  # check it doesn't crash.
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 6559543fc..e86e13c4c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -415,8 +415,7 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            # parameters for Noam
-            "warm_step": 3000,  # arg given to model, not for lrate
+            "warm_step": 2000,
             "env_info": get_env_info(),
         }
     )
@@ -652,18 +651,18 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0 if batch_idx_train < warm_step
-            else 0.1 if batch_idx_train < 2 * warm_step
-            else 1.0
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss scale by warm_step.
+        simple_loss_scale = (
+            s if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
+
         loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            simple_loss_scale * simple_loss
+            + pruned_loss
         )
 
     assert loss.requires_grad == is_training
@@ -793,14 +792,6 @@ def train_one_epoch(
 
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
-            if params.batch_idx_train in [ params.model_warm_step,
-                                           2 * params.model_warm_step ]:
-                logging.info("Resetting optimizer state due to change in loss definition.")
-                # we're about to start using the pruned loss, or rescale it,
-                # so reset the optimizer state, to avoid
-                # possible instability due to the squared stats becoming
-                # inaccurate (too small)
-                optimizer.reset()
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
@@ -1043,7 +1034,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
-            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1136,7 +1126,6 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
-    warmup: float
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1154,7 +1143,6 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()
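For reference, a minimal standalone sketch (not part of the patch) of the loss-scale schedule that compute_loss now applies: the weight on simple_loss decays linearly from 1.0 at batch 0 down to params.simple_loss_scale at warm_step and stays constant afterwards, while pruned_loss is always weighted 1.0. The helper name simple_loss_schedule and the illustrative value simple_loss_scale=0.5 are assumptions; warm_step=2000 matches the new default in get_params above.

def simple_loss_schedule(batch_idx_train: int,
                         warm_step: int = 2000,
                         simple_loss_scale: float = 0.5) -> float:
    # Hypothetical helper, for illustration only: linear ramp of the
    # simple_loss weight from 1.0 (batch 0) to simple_loss_scale (warm_step),
    # constant thereafter.
    s = simple_loss_scale
    if batch_idx_train >= warm_step:
        return s
    return 1.0 - (batch_idx_train / warm_step) * (1.0 - s)


if __name__ == "__main__":
    # With the defaults above this prints 1.0, 0.875, 0.75, 0.5, 0.5.
    for step in [0, 500, 1000, 2000, 10000]:
        print(step, simple_loss_schedule(step))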