diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index eb7776938..b17ebba7c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm.  This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (we suggest 0.1).  This is
+    rms of the parameters approach a particular target_rms (default: 0.1).  This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
 
@@ -43,10 +43,13 @@ class Eve(Optimizer):
             running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
+        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+            this value means that the weight would decay significantly after
+            about 3k minibatches).  It is not multiplied by the learning rate, and
+            it is only applied while the RMS value of the parameter is above target_rms.
+        target_rms (float, optional): target root-mean-square value of the
+            parameters; if a parameter's RMS falls below this, we stop applying weight decay.
+
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -57,7 +60,7 @@ class Eve(Optimizer):
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 target_rms=0.1):
+                 weight_decay=3e-4, target_rms=0.1):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -67,9 +70,12 @@ class Eve(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
         defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
                         target_rms=target_rms)
         super(Eve, self).__init__(params, defaults)
 
@@ -94,6 +100,9 @@ class Eve(Optimizer):
                 if p.grad is None:
                     continue
 
+
+
+
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -124,46 +133,21 @@ class Eve(Optimizer):
                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
+                weight_decay = group['weight_decay']
                 delta = exp_avg / denom
 
-                # we'll be doing: p += delta * step_size.
-                # In the normal case delta_rms (the rms value of the elements of
-                # delta) will be very close to 1.0, but we compute it here so
-                # that if we don't use a particular parameter, its value won't
-                # shrink to zero.
-                # delta_var is the expected change in the variance of the parameter
-                # values, i.e. of E[param_elem^2], due to this step.  It will
-                # be close to 1.
-
-                # Let us define:
-                #   delta_var_from_update = (delta**2).mean() * step_size * step_size
-
-                # Suppose we are going to shrinkage with a small value epsilon (not the
-                # same as the eps above!), i.e. param *= (1-epsilon).  Then
-                # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when
-                # the RMS of the parameters equals target_rms), it follows that
-                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
-                # which we can put as:
-                #   delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
-                # Setting delta_var_from_shrinkage = -delta_var_from_update
-                # because we want them to cancel,
-                #   delta_var_from_update = 2 epsilon target_rms^2, or:
-                #   epsilon = delta_var_from_update / (2 * target_rms^2)
-                #           = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
-                # Note: step_size is close to the learning rate.  For an example, if
-                # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
-                # (delta**2).mean() == 1, we will have:
-                #   epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06.
-                # Note that this is close to the "traditional" value used for weight
-                # decay.
-                #
-                # this is the weight-decay amount...
-                #
-                # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive
-                # frames being correlated.  I have to figure out the justification.
-                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1)))
-
-                p.mul_(1 - weight_decay)
-                p.add_(delta, alpha=-step_size)
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors" (which are scalar).
+                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
+
+# Note on avg-change per epoch:
+# suppose an epoch is 4k iters.
+# If the avg-change, measured as rms(diff) / rms(params), equals 0.2, and rms(params) = 0.1,
+# then rms(diff) = 0.1 * 0.2 = 0.02 and var(diff) = 0.02**2 = 0.0004.  So var(diff per minibatch)
+# = 0.0004 / 4000 = 1e-07, and rms(diff per minibatch) = 3.16e-04.  So the LR would be 3e-04.
+#
+# .. 6e-05 is 1/5 of that...
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 348e2dd47..1340e0950 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -152,17 +152,17 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-decay-steps",
+        "--lr-num-steps",
         type=float,
-        default=5000,
-        help="The number of steps before we decay (halve) the learning rate",
+        default=3000,
+        help="Number of steps before we start to significantly decay the learning rate",
     )
 
     parser.add_argument(
-        "--num-lr-decays",
-        type=int,
-        default=4,
-        help="The total number of times we decay (halve) the learning rate"
+        "--lr-power",
+        type=float,
+        default=0.5,
+        help="Power in the LR-setting rule; the LR is scaled by (lr_num_steps / (step + lr_num_steps)) ** lr_power",
     )
 
     parser.add_argument(
@@ -781,12 +781,10 @@ def run(rank, world_size, args):
 
     optimizer = Eve(
         model.parameters(), lr=params.initial_lr, betas=(0.9, 0.98),
-        eps=1e-9, target_rms=0.1)
-    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        eps=1e-9, weight_decay=3e-04, target_rms=0.1)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
+        optimizer,
-        optimizer,
-        [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
-        gamma=0.5)
-
+        lambda step: ((params.lr_num_steps / (step + params.lr_num_steps)) ** params.lr_power))
 
     if checkpoints and "optimizer" in checkpoints: