mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
First draft of new approach to learning rates + init
This commit is contained in:
parent
4929e4cf32
commit
72f4a673b1
@ -1017,93 +1017,6 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Noam(object):
|
|
||||||
"""
|
|
||||||
Implements Noam optimizer.
|
|
||||||
|
|
||||||
Proposed in
|
|
||||||
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
|
|
||||||
|
|
||||||
Modified from
|
|
||||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params:
|
|
||||||
iterable of parameters to optimize or dicts defining parameter groups
|
|
||||||
model_size:
|
|
||||||
attention dimension of the transformer model
|
|
||||||
factor:
|
|
||||||
learning rate factor
|
|
||||||
warm_step:
|
|
||||||
warmup steps
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
model_size: int = 256,
|
|
||||||
factor: float = 10.0,
|
|
||||||
warm_step: int = 25000,
|
|
||||||
weight_decay=0,
|
|
||||||
) -> None:
|
|
||||||
"""Construct an Noam object."""
|
|
||||||
self.optimizer = torch.optim.Adam(
|
|
||||||
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
|
|
||||||
)
|
|
||||||
self._step = 0
|
|
||||||
self.warmup = warm_step
|
|
||||||
self.factor = factor
|
|
||||||
self.model_size = model_size
|
|
||||||
self._rate = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def param_groups(self):
|
|
||||||
"""Return param_groups."""
|
|
||||||
return self.optimizer.param_groups
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
"""Update parameters and rate."""
|
|
||||||
self._step += 1
|
|
||||||
rate = self.rate()
|
|
||||||
for p in self.optimizer.param_groups:
|
|
||||||
p["lr"] = rate
|
|
||||||
self._rate = rate
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
def rate(self, step=None):
|
|
||||||
"""Implement `lrate` above."""
|
|
||||||
if step is None:
|
|
||||||
step = self._step
|
|
||||||
return (
|
|
||||||
self.factor
|
|
||||||
* self.model_size ** (-0.5)
|
|
||||||
* self.warmup ** (-0.5 - -0.333)
|
|
||||||
* min(step ** (-0.333), step * self.warmup ** (-1.333))
|
|
||||||
)
|
|
||||||
|
|
||||||
def zero_grad(self):
|
|
||||||
"""Reset gradient."""
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
"""Return state_dict."""
|
|
||||||
return {
|
|
||||||
"_step": self._step,
|
|
||||||
"warmup": self.warmup,
|
|
||||||
"factor": self.factor,
|
|
||||||
"model_size": self.model_size,
|
|
||||||
"_rate": self._rate,
|
|
||||||
"optimizer": self.optimizer.state_dict(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
|
||||||
"""Load state_dict."""
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
if key == "optimizer":
|
|
||||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
|
||||||
else:
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
feature_dim = 50
|
feature_dim = 50
|
||||||
|
254
egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
Normal file
254
egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||||
|
#
|
||||||
|
# See ../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class Eve(Optimizer):
|
||||||
|
r"""
|
||||||
|
Implements Eve algorithm. This is a modified version of AdamW with a special
|
||||||
|
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
||||||
|
rms of the parameters approach a particular specified value (generally 0.1). This is
|
||||||
|
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
||||||
|
will be close to invariant to the absolute scale on the parameter matrix.
|
||||||
|
|
||||||
|
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||||
|
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
||||||
|
Eve is unpublished so far.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
|
parameter groups
|
||||||
|
lr (float, optional): learning rate (default: 1e-3)
|
||||||
|
betas (Tuple[float, float], optional): coefficients used for computing
|
||||||
|
running averages of gradient and its square (default: (0.9, 0.999))
|
||||||
|
eps (float, optional): term added to the denominator to improve
|
||||||
|
numerical stability (default: 1e-8)
|
||||||
|
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
|
||||||
|
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||||
|
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||||
|
(default: False)
|
||||||
|
|
||||||
|
.. _Adam\: A Method for Stochastic Optimization:
|
||||||
|
https://arxiv.org/abs/1412.6980
|
||||||
|
.. _Decoupled Weight Decay Regularization:
|
||||||
|
https://arxiv.org/abs/1711.05101
|
||||||
|
.. _On the Convergence of Adam and Beyond:
|
||||||
|
https://openreview.net/forum?id=ryQu7f-RZ
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
|
||||||
|
target_rms=0.1):
|
||||||
|
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||||
|
if not 0 < target_rms <= 10.0:
|
||||||
|
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||||
|
target_rms=target_rms)
|
||||||
|
super(Eve, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(Eve, self).__setstate__(state)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Perform optimization step
|
||||||
|
grad = p.grad
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('AdamW does not support sparse gradients')
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||||
|
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
bias_correction1 = 1 - beta1 ** state['step']
|
||||||
|
bias_correction2 = 1 - beta2 ** state['step']
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||||
|
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||||
|
|
||||||
|
step_size = group['lr'] / bias_correction1
|
||||||
|
target_rms = group['target_rms']
|
||||||
|
delta = exp_avg / denom
|
||||||
|
|
||||||
|
# we'll be doing: p += delta * step_size.
|
||||||
|
# In the normal case delta_rms (the rms value of the elements of
|
||||||
|
# delta) will be very close to 1.0, but we compute it here so
|
||||||
|
# that if we don't use a particular parameter, its value won't
|
||||||
|
# shrink to zero.
|
||||||
|
# delta_var is the expected change in the variance of the parameter
|
||||||
|
# values, i.e. of E[param_elem^2], due to this step. It will
|
||||||
|
# be close to 1.
|
||||||
|
|
||||||
|
# Let us define:
|
||||||
|
# delta_var_from_update = (delta**2).mean() * step_size * step_size
|
||||||
|
|
||||||
|
# Suppose we are going to shrinkage with a small value epsilon (not the
|
||||||
|
# same as the eps above!), i.e. param *= (1-epsilon). Then
|
||||||
|
# if E[param_elem^2] == target_rms^2,
|
||||||
|
# E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
|
||||||
|
# which we can put as:
|
||||||
|
# delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
|
||||||
|
# Setting delta_var_from_shrinkage = -delta_var_from_update
|
||||||
|
# because we want them to cancel,
|
||||||
|
# delta_var_from_update = 2 epsilon target_rms^2, or:
|
||||||
|
# epsilon = delta_var_from_update / (2 * target_rms^2)
|
||||||
|
# = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
|
||||||
|
# Note: step_size is close to the learning rate. For an example, if
|
||||||
|
# lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
|
||||||
|
# (delta**2).mean() == 1, we will have:
|
||||||
|
# epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06.
|
||||||
|
# Note that this is close to the "traditional" value used for weight
|
||||||
|
# decay.
|
||||||
|
|
||||||
|
# this is the weight-decay amount...
|
||||||
|
weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
|
||||||
|
|
||||||
|
p.mul_(1 - weight_decay)
|
||||||
|
p.add_(delta, alpha=-step_size)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Noam(object):
|
||||||
|
"""
|
||||||
|
Implements Noam optimizer.
|
||||||
|
|
||||||
|
Proposed in
|
||||||
|
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
|
||||||
|
|
||||||
|
Modified from
|
||||||
|
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
iterable of parameters to optimize or dicts defining parameter groups
|
||||||
|
model_size:
|
||||||
|
attention dimension of the transformer model
|
||||||
|
factor:
|
||||||
|
learning rate factor
|
||||||
|
warm_step:
|
||||||
|
warmup steps
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
model_size: int = 256,
|
||||||
|
factor: float = 10.0,
|
||||||
|
warm_step: int = 25000,
|
||||||
|
weight_decay=0,
|
||||||
|
) -> None:
|
||||||
|
"""Construct an Noam object."""
|
||||||
|
self.optimizer = torch.optim.Adam(
|
||||||
|
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
self._step = 0
|
||||||
|
self.warmup = warm_step
|
||||||
|
self.factor = factor
|
||||||
|
self.model_size = model_size
|
||||||
|
self._rate = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
"""Return param_groups."""
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
"""Update parameters and rate."""
|
||||||
|
self._step += 1
|
||||||
|
rate = self.rate()
|
||||||
|
for p in self.optimizer.param_groups:
|
||||||
|
p["lr"] = rate
|
||||||
|
self._rate = rate
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
def rate(self, step=None):
|
||||||
|
"""Implement `lrate` above."""
|
||||||
|
if step is None:
|
||||||
|
step = self._step
|
||||||
|
return (
|
||||||
|
self.factor
|
||||||
|
* self.model_size ** (-0.5)
|
||||||
|
* self.warmup ** (-0.5 - -0.333)
|
||||||
|
* min(step ** (-0.333), step * self.warmup ** (-1.333))
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
"""Reset gradient."""
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Return state_dict."""
|
||||||
|
return {
|
||||||
|
"_step": self._step,
|
||||||
|
"warmup": self.warmup,
|
||||||
|
"factor": self.factor,
|
||||||
|
"model_size": self.model_size,
|
||||||
|
"_rate": self._rate,
|
||||||
|
"optimizer": self.optimizer.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""Load state_dict."""
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key == "optimizer":
|
||||||
|
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||||
|
else:
|
||||||
|
setattr(self, key, value)
|
@ -158,7 +158,10 @@ class ScaledLinear(nn.Linear):
|
|||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
|
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.01 / initial_speed
|
# we plan to use Eve as the optimizer, which will eventually make the stddev approach
|
||||||
|
# 0.1 as that's the target_rms we set, but we initialize with a larger stddev
|
||||||
|
# to have the same effect as a warm-up period.
|
||||||
|
std = 0.5 / initial_speed
|
||||||
a = (3 ** 0.5) * std
|
a = (3 ** 0.5) * std
|
||||||
nn.init.uniform_(self.weight, -a, a)
|
nn.init.uniform_(self.weight, -a, a)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
@ -196,7 +199,7 @@ class ScaledConv1d(nn.Conv1d):
|
|||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.01 / initial_speed
|
std = 0.5 / initial_speed
|
||||||
a = (3 ** 0.5) * std
|
a = (3 ** 0.5) * std
|
||||||
nn.init.uniform_(self.weight, -a, a)
|
nn.init.uniform_(self.weight, -a, a)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
@ -241,7 +244,7 @@ class ScaledConv2d(nn.Conv2d):
|
|||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.01 / initial_speed
|
std = 0.5 / initial_speed
|
||||||
a = (3 ** 0.5) * std
|
a = (3 ** 0.5) * std
|
||||||
nn.init.uniform_(self.weight, -a, a)
|
nn.init.uniform_(self.weight, -a, a)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
@ -476,9 +479,8 @@ class ScaledEmbedding(nn.Module):
|
|||||||
self.reset_parameters(initial_speed)
|
self.reset_parameters(initial_speed)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
||||||
std = 0.01 / initial_speed
|
std = 0.5 / initial_speed
|
||||||
nn.init.normal_(self.weight, std=std)
|
nn.init.normal_(self.weight, std=std)
|
||||||
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
|
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
|
||||||
|
|
||||||
|
@ -28,7 +28,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir pruned_transducer_stateless2/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300 \
|
--max-duration 300 \
|
||||||
--lr-factor 1.5
|
--initial-lr 0.002 \
|
||||||
|
--lr-decay-steps 10000 \
|
||||||
|
--num-lr-decays 4
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -52,6 +55,7 @@ from lhotse.cut import Cut
|
|||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
from optim import Eve
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@ -141,17 +145,24 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-factor",
|
"--initial-lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=5.0,
|
default=0.002,
|
||||||
help="The lr_factor for Noam optimizer",
|
help="The initial learning rate",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warm-step",
|
"--lr-decay-steps",
|
||||||
type=float,
|
type=float,
|
||||||
default=60000,
|
default=5000,
|
||||||
help="The number of warmup steps for the (modified) Noam optimizer",
|
help="The number of steps before we decay (halve) the learning rate",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-lr-decays",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="The total number of times we decay (halve) the learning rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -426,6 +437,7 @@ def save_checkpoint(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
|
||||||
sampler: Optional[CutSampler] = None,
|
sampler: Optional[CutSampler] = None,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -449,6 +461,7 @@ def save_checkpoint(
|
|||||||
model=model,
|
model=model,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
@ -574,6 +587,7 @@ def train_one_epoch(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
|
scheduler: torch.optim.lr_scheduler._LRScheduler,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
@ -594,6 +608,8 @@ def train_one_epoch(
|
|||||||
The model for training.
|
The model for training.
|
||||||
optimizer:
|
optimizer:
|
||||||
The optimizer we are using.
|
The optimizer we are using.
|
||||||
|
scheduler:
|
||||||
|
The learning rate scheduler, we call step() every step.
|
||||||
train_dl:
|
train_dl:
|
||||||
Dataloader for the training dataset.
|
Dataloader for the training dataset.
|
||||||
valid_dl:
|
valid_dl:
|
||||||
@ -636,6 +652,7 @@ def train_one_epoch(
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
@ -651,6 +668,7 @@ def train_one_epoch(
|
|||||||
model=model,
|
model=model,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
sampler=train_dl.sampler,
|
sampler=train_dl.sampler,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
@ -756,17 +774,24 @@ def run(rank, world_size, args):
|
|||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
optimizer = Noam(
|
optimizer = Eve(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
model_size=params.encoder_dim,
|
lr=params.initial_lr, betas=(0.9, 0.98),
|
||||||
factor=params.lr_factor,
|
eps=1e-9, target_rms=0.1)
|
||||||
warm_step=params.warm_step,
|
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
)
|
optimizer,
|
||||||
|
[ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
|
||||||
|
gamma=0.5)
|
||||||
|
|
||||||
|
|
||||||
if checkpoints and "optimizer" in checkpoints:
|
if checkpoints and "optimizer" in checkpoints:
|
||||||
logging.info("Loading optimizer state dict")
|
logging.info("Loading optimizer state dict")
|
||||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
|
if checkpoints and "scheduler" in checkpoints:
|
||||||
|
logging.info("Loading scheduler state dict")
|
||||||
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||||
|
|
||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
@ -839,6 +864,7 @@ def run(rank, world_size, args):
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user