mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Fixes to cain
This commit is contained in:
parent
6085ab64ef
commit
1edc0fa841
@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
from typing import List, Optional, Union, Tuple, List
|
||||
|
||||
from lhotse.utils import fix_random_seed
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
@ -652,7 +652,7 @@ class Cain(Optimizer):
|
||||
orig_dim = dim
|
||||
step = state["step"]
|
||||
must_store_stats, must_zero_stats = self._must_store_stats(step)
|
||||
if must_store_stats and allow_store_stats:
|
||||
if must_store_stats and allow_store_stats and forward:
|
||||
# store stats for 100 iters preceding when we estimate the
|
||||
# transform.
|
||||
stats_name = f"cov_{orig_dim}"
|
||||
@ -673,7 +673,7 @@ class Cain(Optimizer):
|
||||
dtype=grad.dtype)
|
||||
|
||||
must_estimate_transform = self._must_estimate_transform(step)
|
||||
if must_estimate_transform:
|
||||
if must_estimate_transform and forward:
|
||||
stats_name = f"cov_{orig_dim}"
|
||||
cov = state[stats_name]
|
||||
l, U = cov.symeig(eigenvectors=True)
|
||||
@ -832,6 +832,152 @@ class LRScheduler(object):
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
|
||||
class Eve(Optimizer):
|
||||
"""
|
||||
Implements Eve algorithm. This is a modified version of AdamW with a special
|
||||
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
||||
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
||||
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
||||
will be close to invariant to the absolute scale on the parameter matrix.
|
||||
|
||||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
||||
Eve is unpublished so far.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
||||
this value means that the weight would decay significantly after
|
||||
about 3k minibatches. Is not multiplied by learning rate, but
|
||||
is conditional on RMS-value of parameter being > target_rms.
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
parameters, if they fall below this we will stop applying weight decay.
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-3,
|
||||
target_rms=0.1,
|
||||
):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 0: {}".format(betas[0])
|
||||
)
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
if not 0 <= weight_decay <= 0.1:
|
||||
raise ValueError(
|
||||
"Invalid weight_decay value: {}".format(weight_decay)
|
||||
)
|
||||
if not 0 < target_rms <= 10.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
target_rms=target_rms,
|
||||
)
|
||||
super(Eve, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(Eve, self).__setstate__(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
# Perform optimization step
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"AdamW does not support sparse gradients"
|
||||
)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
|
||||
group["eps"]
|
||||
)
|
||||
|
||||
step_size = group["lr"] / bias_correction1
|
||||
target_rms = group["target_rms"]
|
||||
weight_decay = group["weight_decay"]
|
||||
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = p.norm() > (
|
||||
target_rms * (p.numel() ** 0.5)
|
||||
)
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class Eden(LRScheduler):
|
||||
"""
|
||||
@ -897,13 +1043,13 @@ def _test_eden():
|
||||
print("state dict = ", scheduler.state_dict())
|
||||
|
||||
|
||||
def _test_abel():
|
||||
def _test_eve_cain():
|
||||
import timeit
|
||||
from scaling import ScaledLinear
|
||||
E = 100
|
||||
B = 4
|
||||
T = 2
|
||||
print("in test_abel")
|
||||
print("in test_eve_cain")
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.float32
|
||||
|
||||
@ -913,7 +1059,8 @@ def _test_abel():
|
||||
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
|
||||
for iter in [1,0]:
|
||||
for iter in [0,1]:
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
m = torch.nn.Sequential(Linear(E, 200),
|
||||
torch.nn.ReLU(),
|
||||
@ -922,7 +1069,7 @@ def _test_abel():
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
|
||||
|
||||
if iter == 0: optim = Abel(m.parameters(), lr=0.003)
|
||||
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
|
||||
else: optim = Cain(m.parameters(), lr=0.003)
|
||||
scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
|
||||
|
||||
@ -974,5 +1121,5 @@ def _test_abel():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_abel()
|
||||
_test_eve_cain()
|
||||
#_test_eden()
|
||||
|
Loading…
x
Reference in New Issue
Block a user