First draft of new approach to learning rates + init

This commit is contained in:
Daniel Povey 2022-04-04 20:21:34 +08:00
parent 4929e4cf32
commit 72f4a673b1
4 changed files with 299 additions and 104 deletions


@@ -1017,93 +1017,6 @@ class Conv2dSubsampling(nn.Module):
return x
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* self.warmup ** (-0.5 - -0.333)
* min(step ** (-0.333), step * self.warmup ** (-1.333))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
if __name__ == '__main__':
feature_dim = 50


@@ -0,0 +1,254 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from torch.optim import Optimizer
class Eve(Optimizer):
r"""
Implements Eve algorithm. This is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular specified value (generally 0.1). This is
for use with networks with 'scaled' versions of modules (see scaling.py), which
will be close to invariant to the absolute scale on the parameter matrix.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Eve is unpublished so far.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.98))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
target_rms (float, optional): target root-mean-square value of the
parameters; the weight-decay / shrinkage factor is chosen so that the
parameter rms approaches this value (default: 0.1)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
target_rms=0.1):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(lr=lr, betas=betas, eps=eps,
target_rms=target_rms)
super(Eve, self).__init__(params, defaults)
def __setstate__(self, state):
super(Eve, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Eve does not support sparse gradients')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
step_size = group['lr'] / bias_correction1
target_rms = group['target_rms']
delta = exp_avg / denom
# we'll be doing: p += delta * step_size.
# In the normal case delta_rms (the rms value of the elements of
# delta) will be very close to 1.0, but we compute it here so
# that if we don't use a particular parameter, its value won't
# shrink to zero.
# delta_var is the expected change in the variance of the parameter
# values, i.e. of E[param_elem^2], due to this step. It will
# be close to 1.
# Let us define:
# delta_var_from_update = (delta**2).mean() * step_size * step_size
# Suppose we are going to apply shrinkage with a small value epsilon (not the
# same as the eps above!), i.e. param *= (1-epsilon). Then
# if E[param_elem^2] == target_rms^2,
# E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
# which we can put as:
# delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
# Setting delta_var_from_shrinkage = -delta_var_from_update
# because we want them to cancel,
# delta_var_from_update = 2 epsilon target_rms^2, or:
# epsilon = delta_var_from_update / (2 * target_rms^2)
# = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
# Note: step_size is close to the learning rate. For example, if
# lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
# (delta**2).mean() == 1, we will have:
# epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1)**2 = 5.0e-07.
# Note that this is close to the "traditional" value used for weight
# decay.
# this is the weight-decay amount...
weight_decay = (delta ** 2).mean().sqrt() * (0.5 * (step_size / target_rms) ** 2)
p.mul_(1 - weight_decay)
p.add_(delta, alpha=-step_size)
return loss
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* self.warmup ** (-0.5 - -0.333)
* min(step ** (-0.333), step * self.warmup ** (-1.333))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
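
Before moving on to the init and training-script changes, here is a minimal usage sketch of the Eve optimizer defined in this new file. The toy model, shapes and loss are made up for illustration; in the recipe the optimizer is built in train.py from the --initial-lr flag, as shown further below.

import torch
import torch.nn as nn
from optim import Eve  # assumes the file above is saved as optim.py

# A stand-in model; the real recipe passes the transducer's parameters.
model = nn.Linear(80, 512)
optimizer = Eve(model.parameters(), lr=0.002, betas=(0.9, 0.98),
                eps=1e-9, target_rms=0.1)

for _ in range(10):
    x = torch.randn(32, 80)
    loss = model(x).pow(2).mean()  # dummy loss for illustration
    loss.backward()
    optimizer.step()               # Adam-style update plus rms-targeted shrinkage
    optimizer.zero_grad()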


@@ -158,7 +158,10 @@ class ScaledLinear(nn.Linear):
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self, initial_speed: float):
std = 0.01 / initial_speed
# we plan to use Eve as the optimizer, which will eventually make the stddev approach
# 0.1 as that's the target_rms we set, but we initialize with a larger stddev
# to have the same effect as a warm-up period.
std = 0.5 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
@@ -196,7 +199,7 @@ class ScaledConv1d(nn.Conv1d):
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.01 / initial_speed
std = 0.5 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
@@ -241,7 +244,7 @@ class ScaledConv2d(nn.Conv2d):
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.01 / initial_speed
std = 0.5 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
@@ -476,9 +479,8 @@ class ScaledEmbedding(nn.Module):
self.reset_parameters(initial_speed)
def reset_parameters(self, initial_speed: float = 1.0) -> None:
std = 0.01 / initial_speed
std = 0.5 / initial_speed
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
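
A quick standalone check of what the new initialization does (tensor size and initial_speed value are arbitrary): a uniform distribution on [-a, a] with a = sqrt(3) * std has standard deviation std, so the weights now start with an rms of about 0.5, well above the target_rms of 0.1 that Eve's shrinkage will pull them toward.

import torch

initial_speed = 1.0
std = 0.5 / initial_speed   # new value; was 0.01 / initial_speed
a = (3 ** 0.5) * std        # uniform on [-a, a] has stddev a / sqrt(3) == std

w = torch.empty(512, 512)
torch.nn.init.uniform_(w, -a, a)
print(w.std())              # approximately 0.5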


@@ -28,7 +28,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 300 \
--lr-factor 1.5
--initial-lr 0.002 \
--lr-decay-steps 10000 \
--num-lr-decays 4
"""
@@ -52,6 +55,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eve
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -141,17 +145,24 @@ def get_parser():
)
parser.add_argument(
"--lr-factor",
"--initial-lr",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
default=0.002,
help="The initial learning rate",
)
parser.add_argument(
"--warm-step",
"--lr-decay-steps",
type=float,
default=60000,
help="The number of warmup steps for the (modified) Noam optimizer",
default=5000,
help="The number of steps before we decay (halve) the learning rate",
)
parser.add_argument(
"--num-lr-decays",
type=int,
default=4,
help="The total number of times we decay (halve) the learning rate",
)
parser.add_argument(
@@ -426,6 +437,7 @@ def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
@@ -449,6 +461,7 @@ def save_checkpoint(
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=sampler,
rank=rank,
)
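
A minimal standalone sketch of why the scheduler is now threaded through checkpointing (this is not the icefall save_checkpoint helper itself; the object setup and file name are made up): the scheduler's state_dict is saved and restored next to the optimizer's, so a resumed run continues the same learning-rate trajectory.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.002)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10000, 20000], gamma=0.5)

torch.save({"optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()}, "checkpoint.pt")

ckpt = torch.load("checkpoint.pt", map_location="cpu")
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])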
@@ -574,6 +587,7 @@ def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
@@ -594,6 +608,8 @@ def train_one_epoch(
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler; its step() method is called after every optimizer update.
train_dl:
Dataloader for the training dataset.
valid_dl:
@@ -636,6 +652,7 @@ def train_one_epoch(
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
if params.print_diagnostics and batch_idx == 5:
return
@@ -651,6 +668,7 @@ def train_one_epoch(
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
rank=rank,
)
@@ -756,17 +774,24 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
optimizer = Eve(
model.parameters(),
model_size=params.encoder_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
lr=params.initial_lr, betas=(0.9, 0.98),
eps=1e-9, target_rms=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
[ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
gamma=0.5)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if checkpoints and "scheduler" in checkpoints:
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
@@ -839,6 +864,7 @@ def run(rank, world_size, args):
params=params,
model=model,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,