diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
index 9c4646009..05b0efe03 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
@@ -16,7 +16,7 @@
 
 from typing import List, Optional, Union, Tuple, List
-
+from lhotse.utils import fix_random_seed
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -652,7 +652,7 @@ class Cain(Optimizer):
             orig_dim = dim
             step = state["step"]
             must_store_stats, must_zero_stats = self._must_store_stats(step)
-            if must_store_stats and allow_store_stats:
+            if must_store_stats and allow_store_stats and forward:
                 # store stats for 100 iters preceding when we estimate the
                 # transform.
                 stats_name = f"cov_{orig_dim}"
@@ -673,7 +673,7 @@ class Cain(Optimizer):
                                                dtype=grad.dtype)
 
             must_estimate_transform = self._must_estimate_transform(step)
-            if must_estimate_transform:
+            if must_estimate_transform and forward:
                 stats_name = f"cov_{orig_dim}"
                 cov = state[stats_name]
                 l, U = cov.symeig(eigenvectors=True)
@@ -832,6 +832,152 @@ class LRScheduler(object):
                 f" of group {group} to {lr:.4e}."
             )
 
+class Eve(Optimizer):
+    """
+    Implements Eve algorithm.  This is a modified version of AdamW with a special
+    way of setting the weight-decay / shrinkage-factor, which is designed to make the
+    rms of the parameters approach a particular target_rms (default: 0.1).  This is
+    for use with networks with 'scaled' versions of modules (see scaling.py), which
+    will be close to invariant to the absolute scale on the parameter matrix.
+
+    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+    Eve is unpublished so far.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.98))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-3;
+            this value means that the weight would decay significantly after
+            about 1k minibatches.  Is not multiplied by learning rate, but
+            is conditional on RMS-value of parameter being > target_rms).
+        target_rms (float, optional): target root-mean-square value of
+            parameters; if they fall below this we will stop applying weight decay.
+
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
+
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
+        if not 0 < target_rms <= 10.0:
+            raise ValueError("Invalid target_rms value: {}".format(target_rms))
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
+        super(Eve, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Eve, self).__setstate__(state)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                # Perform optimization step
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "AdamW does not support sparse gradients"
+                    )
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )
+
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]
+
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors"
+                    # (which are scalar).
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+
+                p.addcdiv_(exp_avg, denom, value=-step_size)
+
+        return loss
+
 
 class Eden(LRScheduler):
     """
@@ -897,13 +1043,13 @@ def _test_eden():
     print("state dict = ", scheduler.state_dict())
 
 
-def _test_abel():
+def _test_eve_cain():
    import timeit
    from scaling import ScaledLinear
    E = 100
    B = 4
    T = 2
-    print("in test_abel")
+    print("in test_eve_cain")
    device = torch.device('cuda')
    dtype = torch.float32
 
@@ -913,7 +1059,8 @@ def _test_abel():
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
     output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
 
-    for iter in [1,0]:
+    for iter in [0,1]:
+        fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
                                 torch.nn.ReLU(),
@@ -922,7 +1069,7 @@
         train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
                          torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes)
                         for _ in range(20) ]
-        if iter == 0: optim = Abel(m.parameters(), lr=0.003)
+        if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         else: optim = Cain(m.parameters(), lr=0.003)
         scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
 
@@ -974,5 +1121,5 @@ def _test_abel():
 
 
 if __name__ == "__main__":
-    _test_abel()
+    _test_eve_cain()
     #_test_eden()
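
Reviewer note (illustrative, not part of the patch): the snippet below sketches how the Eve optimizer added here is intended to be used, assuming this patched optim.py is importable as `optim` (e.g. when run from the recipe directory, as train.py does). The model and tensor shapes are invented for the example.

    import torch

    from optim import Eve  # the class added by this patch

    model = torch.nn.Linear(80, 256)
    # lr matches _test_eve_cain above; weight_decay and target_rms are the defaults.
    optimizer = Eve(model.parameters(), lr=0.003, weight_decay=1e-3, target_rms=0.1)

    for _ in range(100):
        x = torch.randn(16, 80)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        # Eve.step() applies an AdamW-style update; non-scalar parameters are
        # additionally shrunk by (1 - weight_decay), but only while their RMS
        # stays above target_rms.  The shrinkage is not scaled by the learning rate.
        optimizer.step()

In the recipe, Eve is paired with the Eden scheduler (as in _test_eve_cain above), which adjusts the learning rate as a function of both the batch count (lr_batches) and the epoch count (lr_epochs).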