diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index edbebcceb..6f19807dc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (generally 0.1). This is
+    rms of the parameters approach a particular specified value (we suggest 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.

@@ -120,7 +120,7 @@ class Eve(Optimizer):
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])

                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
@@ -141,7 +141,7 @@ class Eve(Optimizer):
                 # Suppose we are going to shrinkage with a small value epsilon (not the
                 # same as the eps above!), i.e. param *= (1-epsilon).  Then
                 # if E[param_elem^2] == target_rms^2,
-                #   E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
+                #   E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
                 # which we can put as:
                 #   delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
                 # Setting delta_var_from_shrinkage = -delta_var_from_update
@@ -157,98 +157,9 @@ class Eve(Optimizer):
                 # decay.

                 # this is the weight-decay amount...
-                weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
+                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)

                 p.mul_(1 - weight_decay)
                 p.add_(delta, alpha=-step_size)

         return loss
-
-
-
-class Noam(object):
-    """
-    Implements Noam optimizer.
-
-    Proposed in
-    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
-
-    Args:
-      params:
-        iterable of parameters to optimize or dicts defining parameter groups
-      model_size:
-        attention dimension of the transformer model
-      factor:
-        learning rate factor
-      warm_step:
-        warmup steps
-    """
-
-    def __init__(
-        self,
-        params,
-        model_size: int = 256,
-        factor: float = 10.0,
-        warm_step: int = 25000,
-        weight_decay=0,
-    ) -> None:
-        """Construct an Noam object."""
-        self.optimizer = torch.optim.Adam(
-            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
-        )
-        self._step = 0
-        self.warmup = warm_step
-        self.factor = factor
-        self.model_size = model_size
-        self._rate = 0
-
-    @property
-    def param_groups(self):
-        """Return param_groups."""
-        return self.optimizer.param_groups
-
-    def step(self):
-        """Update parameters and rate."""
-        self._step += 1
-        rate = self.rate()
-        for p in self.optimizer.param_groups:
-            p["lr"] = rate
-        self._rate = rate
-        self.optimizer.step()
-
-    def rate(self, step=None):
-        """Implement `lrate` above."""
-        if step is None:
-            step = self._step
-        return (
-            self.factor
-            * self.model_size ** (-0.5)
-            * self.warmup ** (-0.5 - -0.333)
-            * min(step ** (-0.333), step * self.warmup ** (-1.333))
-        )
-
-    def zero_grad(self):
-        """Reset gradient."""
-        self.optimizer.zero_grad()
-
-    def state_dict(self):
-        """Return state_dict."""
-        return {
-            "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        """Load state_dict."""
-        for key, value in state_dict.items():
-            if key == "optimizer":
-                self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
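
The weight-decay change above follows directly from the in-line derivation: the shrinkage is sized so that the variance it removes cancels the variance the update injects, so the parameter RMS drifts toward `target_rms` instead of growing without bound. Below is a self-contained numerical sketch of that fixed point (illustration only, not code from this patch; the step size and dimension are arbitrary, and a standard-normal `delta` stands in for the Adam-normalised update):

```python
import torch

torch.manual_seed(0)
step_size, target_rms = 3e-3, 0.1
p = torch.randn(1000)                 # start with rms ~= 1.0, far above target_rms
for _ in range(20000):
    delta = torch.randn_like(p)       # stand-in for exp_avg / denom, rms ~= 1
    # same formula as the '+' line above
    weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
    p.mul_(1 - weight_decay)
    p.add_(delta, alpha=-step_size)
print(f"rms(p) = {p.pow(2).mean().sqrt().item():.3f}")  # settles near 0.1
```

Starting the parameter well below `target_rms` ends up at the same limit, which is the point of balancing `delta_var_from_shrinkage` against `delta_var_from_update`.
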
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 9d074fdd4..9f73c8fbc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -48,7 +48,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer, Noam
+from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -437,7 +437,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -652,7 +652,7 @@ def train_one_epoch(
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
-        lr_scheduler.step()
+        scheduler.step()

         if params.print_diagnostics and batch_idx == 5:
             return
@@ -848,7 +848,7 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

-        cur_lr = optimizer._rate
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
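
With the Noam wrapper gone, train.py drives a plain optimizer plus a `torch.optim.lr_scheduler._LRScheduler`, which is why the loop now calls `scheduler.step()` (the name `lr_scheduler` was never defined there) and why `run()` reads the learning rate via `scheduler.get_last_lr()[0]` instead of the removed `optimizer._rate`. A minimal sketch of that pattern (stand-ins only: `LambdaLR` with a toy warmup is a placeholder for whatever scheduler this recipe actually constructs):

```python
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 1000)  # toy warmup
)

for batch_idx in range(100):                 # stands in for the batch loop
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()                         # per batch, after optimizer.step()

cur_lr = scheduler.get_last_lr()[0]          # what run() now logs to TensorBoard
print(f"current lr: {cur_lr:.2e}")
```
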
@@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
             loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup = 0.0
             )
             loss.backward()
             optimizer.step()
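
The new comment is about gradient flow rather than speed: at `warmup = 0.0` the pruned-loss term is expected to contribute nothing to the total loss, so the parameters that only that term touches receive exactly zero gradient during this OOM scan, and the optimizer's decaying averages (and hence the Eve-style shrinkage) never pick them up. A hedged illustration of that idea (the scaling schedule below is a guess for illustration, not necessarily the one this recipe's model code uses):

```python
import torch

def pruned_loss_scale(warmup: float) -> float:
    # illustrative schedule only: exactly 0.0 while warmup < 1.0
    return 0.0 if warmup < 1.0 else (0.1 if warmup < 2.0 else 1.0)

w = torch.randn(4, requires_grad=True)          # hypothetical pruned-loss-only parameter
simple_loss = torch.tensor(2.0, requires_grad=True)
pruned_loss = (w ** 2).sum()

loss = simple_loss + pruned_loss_scale(warmup=0.0) * pruned_loss
loss.backward()
print(w.grad)   # all zeros: nothing for the decaying averages to remember
```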