mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
154 lines
6.3 KiB
Python
154 lines
6.3 KiB
Python
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
|
#
|
|
# See ../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import random
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
class Eve(Optimizer):
    r"""
    Implements Eve algorithm.  This is a modified version of AdamW with a special
    way of setting the weight-decay / shrinkage-factor, which is designed to make the
    rms of the parameters approach a particular target_rms (default: 0.1).  This is
    for use with networks with 'scaled' versions of modules (see scaling.py), which
    will be close to invariant to the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
            this value means that the weight would decay significantly after
            about 3k minibatches).  Is not multiplied by learning rate, but
            is conditional on RMS-value of parameter being > target_rms.
            Must be in the range [0, 0.1].
        target_rms (float, optional): target root-mean-square value of
            parameters, if they fall below this we will stop applying weight decay.
            Must be in the range (0, 10].

    Raises:
        ValueError: if any hyper-parameter is outside its accepted range.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=3e-4,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0 < target_rms <= 10.0:
            raise ValueError(f"Invalid target_rms value: {target_rms}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd re-enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            target_rms = group["target_rms"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Eve does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(eps)

                step_size = lr / bias_correction1

                if p.numel() > 1:
                    # Avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
                    # rms(p) > target_rms  <=>  ||p|| > target_rms * sqrt(numel).
                    # Shrink only while the parameter is ABOVE the target rms;
                    # once it falls below, decay stops.  The comparison result
                    # is kept as a tensor and folded into the multiplier.
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
|
|
|
|
# Note on avg-change per epoch..
|
|
# suppose epoch is 4k iters.
|
|
# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1,
|
|
# then rms(diff) = 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2 = 0.0004. So var(diff per minibatch)
|
|
# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04.
|
|
#
|
|
# .. 6e-05 is 1/5 of that...
|