Fixes to cain

Daniel Povey 2022-05-19 22:21:41 +08:00
parent 6085ab64ef
commit 1edc0fa841

@@ -16,7 +16,7 @@
 from typing import List, Optional, Union, Tuple, List
+from lhotse.utils import fix_random_seed
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
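
The newly imported `fix_random_seed` is used further down to give the Eve and Cain runs identical data and initialization. As a rough sketch of what such a helper amounts to (an assumption about lhotse's implementation, which may seed more than these three RNGs):

    import random

    import numpy as np
    import torch

    def fix_random_seed_sketch(seed: int) -> None:
        # Seed the RNGs that typically matter in a PyTorch recipe.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)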
@@ -652,7 +652,7 @@ class Cain(Optimizer):
         orig_dim = dim
         step = state["step"]
         must_store_stats, must_zero_stats = self._must_store_stats(step)
-        if must_store_stats and allow_store_stats:
+        if must_store_stats and allow_store_stats and forward:
             # store stats for 100 iters preceding when we estimate the
             # transform.
             stats_name = f"cov_{orig_dim}"
@@ -673,7 +673,7 @@ class Cain(Optimizer):
                                 dtype=grad.dtype)
         must_estimate_transform = self._must_estimate_transform(step)
-        if must_estimate_transform:
+        if must_estimate_transform and forward:
             stats_name = f"cov_{orig_dim}"
             cov = state[stats_name]
             l, U = cov.symeig(eigenvectors=True)
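
A note on the last context line above: `Tensor.symeig` is deprecated in newer PyTorch releases and has since been removed, so on current versions the same eigendecomposition would be written with `torch.linalg.eigh`. A minimal sketch of the replacement, assuming a symmetric covariance like the one stored here:

    import torch

    cov = torch.randn(4, 4)
    cov = cov @ cov.t()  # symmetric PSD, standing in for the stored grad covariance

    # Modern equivalent of `l, U = cov.symeig(eigenvectors=True)`:
    l, U = torch.linalg.eigh(cov)  # eigenvalues in ascending order, as with symeig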
@@ -832,6 +832,152 @@ class LRScheduler(object):
                 f" of group {group} to {lr:.4e}."
             )
+
+class Eve(Optimizer):
+    """
+    Implements Eve algorithm.  This is a modified version of AdamW with a
+    special way of setting the weight-decay / shrinkage-factor, which is
+    designed to make the rms of the parameters approach a particular
+    target_rms (default: 0.1).  This is for use with networks with 'scaled'
+    versions of modules (see scaling.py), which will be close to invariant to
+    the absolute scale of the parameter matrix.
+
+    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+    Eve is unpublished so far.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.98))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-3;
+            this value means that the weight would decay significantly after
+            about 1k minibatches).  It is not multiplied by the learning rate,
+            but is conditional on the RMS-value of the parameter being
+            > target_rms.
+        target_rms (float, optional): target root-mean-square value of
+            parameters; if they fall below this we will stop applying weight
+            decay.
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
+        if not 0 < target_rms <= 10.0:
+            raise ValueError("Invalid target_rms value: {}".format(target_rms))
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
+        super(Eve, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Eve, self).__setstate__(state)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                # Perform optimization step
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError("Eve does not support sparse gradients")
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )
+
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]
+
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors"
+                    # (which are scalar).
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                p.addcdiv_(exp_avg, denom, value=-step_size)
+
+        return loss
+
+
 class Eden(LRScheduler):
     """
@@ -897,13 +1043,13 @@ def _test_eden():
     print("state dict = ", scheduler.state_dict())
 
-def _test_abel():
+def _test_eve_cain():
     import timeit
     from scaling import ScaledLinear
 
     E = 100
     B = 4
     T = 2
-    print("in test_abel")
+    print("in test_eve_cain")
     device = torch.device('cuda')
     dtype = torch.float32
@@ -913,7 +1059,8 @@ def _test_abel():
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
     output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
 
-    for iter in [1,0]:
+    for iter in [0,1]:
+        fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
                                 torch.nn.ReLU(),
@@ -922,7 +1069,7 @@ def _test_abel():
         train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
                          torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
 
-        if iter == 0: optim = Abel(m.parameters(), lr=0.003)
+        if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         else: optim = Cain(m.parameters(), lr=0.003)
         scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
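
Since the body of the training loop falls outside these hunks, here is a minimal sketch of how an optimizer like Eve pairs with the Eden scheduler, assuming Eden exposes `step_batch()`/`step_epoch()` hooks (an assumption; only `state_dict()` appears in these hunks, and names mirror the test above):

    # Sketch only; `m`, `train_pairs`, `optim` and `scheduler` as constructed above.
    for epoch in range(2):
        scheduler.step_epoch()      # advance the epoch-dependent part of the LR
        for x, y in train_pairs:
            loss = ((m(x) - y) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
            scheduler.step_batch()  # advance the batch-dependent part of the LR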
@@ -974,5 +1121,5 @@ def _test_abel():
 
 if __name__ == "__main__":
-    _test_abel()
+    _test_eve_cain()
     #_test_eden()