mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Add train.py
This commit is contained in:
parent
894be068e7
commit
c3a8727446
@ -3,7 +3,8 @@ import torch.distributed as dist
|
|||||||
import k2
|
import k2
|
||||||
import _k2
|
import _k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
from typing import Optional, List, Tuple
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ class LmDataset(torch.utils.data.Dataset):
|
|||||||
return k2.index(self.words, sentence).values().tolist()
|
return k2.index(self.words, sentence).values().tolist()
|
||||||
|
|
||||||
|
|
||||||
def load_train_test_lm_dataset(archive_fn: str,
|
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
||||||
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
|
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
|
||||||
"""
|
"""
|
||||||
returns (train_lm_dataset, test_lm_dataset)
|
returns (train_lm_dataset, test_lm_dataset)
|
||||||
|
959
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
959
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
@ -0,0 +1,959 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# After this many warnings about infinite gradients we'll die.
|
||||||
|
inf_grad_count = 0
|
||||||
|
inf_grad_max_count = 20
|
||||||
|
|
||||||
|
class Madam(Optimizer):
|
||||||
|
r"""Madam is a modification of the Adam algorithm, with various changes
|
||||||
|
intended to support certain "common-sense" ideas and solve common
|
||||||
|
pathologies that can happen particularly in transformer-type models that
|
||||||
|
have multiplication of parameters (particularly, key and query matrices)--
|
||||||
|
these can be vulnerable to "subspace loss" where, if you have any l2
|
||||||
|
regularization, certain subspaces in the key/query space might get
|
||||||
|
regularized towards zero. We solve this with a special formula that
|
||||||
|
changes how the l2/weight-decay is done (see compute_l2_grad()).
|
||||||
|
I'll try to write the math down at some point. This formula only
|
||||||
|
applies to tensors that have at least two dimensions; for one-dimensional
|
||||||
|
tensors we simply won't do l2 regularization.
|
||||||
|
|
||||||
|
One more thing-- there is a special pathology that can sometimes afflict
|
||||||
|
models like LSTMs, where a particular element of a minibatch experiences
|
||||||
|
gradient blowup in the backward pass. We'd like to identify such cases and
|
||||||
|
fix it somehow, e.g. by removing or scaling down the gradient for that
|
||||||
|
particular minibatch. We can identify and somewhat fix this by seeing that the
|
||||||
|
gradient norm (computed over all the parameters in a parameter group) is
|
||||||
|
much more than on previous minibatches, and limiting it to (the preceding
|
||||||
|
average step size times some constant).
|
||||||
|
|
||||||
|
Like most optimization algorithms, for this to work well you need to
|
||||||
|
have an appropriate learning rate schedule, either decreasing with
|
||||||
|
time, or increasing (warm-up) and then decreasing. The LR schedule may
|
||||||
|
possibly need to decrease a little more aggressively than you would with
|
||||||
|
Adam, or at least have smaller values overall than Adam, because
|
||||||
|
the smaller parameters will mean the effective (relative) learning
|
||||||
|
rate is higher.
|
||||||
|
|
||||||
|
This is modified from PyTorch's optim/adam.py
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
|
parameter groups
|
||||||
|
lr (float, optional): learning rate (default: 1e-3)
|
||||||
|
betas (Tuple[float, float], optional): coefficients used for computing
|
||||||
|
running averages of gradient and its square (default: (0.9, 0.999))
|
||||||
|
eps (float, optional): term added to the denominator to improve
|
||||||
|
numerical stability (default: 1e-8)
|
||||||
|
grad_norm_buffer_size (int, optional): Buffer size used in detecting
|
||||||
|
minibatches with unusually large gradients and scaling them down.
|
||||||
|
limit_grad_factor (float): factor by which we don't allow the
|
||||||
|
gradient to be greater than the average of previous gradients
|
||||||
|
(we'll scale the gradient down, over the whole param-group,
|
||||||
|
to enforce this). Must be greater than 1. Set to float('inf')
|
||||||
|
to disable norm clipping.
|
||||||
|
min_target_rms: A floor on the "target rms" of each Tensor, so
|
||||||
|
that Tensors that, when initialized, have less than this
|
||||||
|
rms value will have their target rms value floored to this
|
||||||
|
l2: True to enable l2 regularization
|
||||||
|
l2_period: You may set this to a value greater than one to save
|
||||||
|
computation by only periodically doing the l2 update.
|
||||||
|
We include a scaling factor in the formula so that, as far
|
||||||
|
as possible (for small learning rates) this shouldn't affect
|
||||||
|
the results. (Note: this probably isn't necessary to set,
|
||||||
|
since it turns out the update is quite fast, at least on GPU,
|
||||||
|
and the gradient clipping is actually more of a problem)
|
||||||
|
|
||||||
|
|
||||||
|
.. _Adam\: A Method for Stochastic Optimization:
|
||||||
|
https://arxiv.org/abs/1412.6980
|
||||||
|
.. _Decoupled Weight Decay Regularization:
|
||||||
|
https://arxiv.org/abs/1711.05101
|
||||||
|
.. _On the Convergence of Adam and Beyond:
|
||||||
|
https://openreview.net/forum?id=ryQu7f-RZ
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params,
|
||||||
|
lr: float = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.999),
|
||||||
|
eps: float = 1e-8,
|
||||||
|
grad_norm_buffer_size: int = 8,
|
||||||
|
limit_grad_factor: float = 2.0,
|
||||||
|
min_target_rms: float = 0.05,
|
||||||
|
l2: bool = True,
|
||||||
|
l2_period: int = 1):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||||
|
if not (isinstance(grad_norm_buffer_size, int) and grad_norm_buffer_size > 1):
|
||||||
|
raise ValueError("Invalid grad_norm_buffer_size value: {}".format(grad_norm_buffer_size))
|
||||||
|
if not limit_grad_factor > 1.0:
|
||||||
|
raise ValueError("Invalid limit_grad_factor: {}".format(limit_grad_factor))
|
||||||
|
if not isinstance(l2, bool):
|
||||||
|
raise ValueError("Invalid l2 value: {}".format(l2))
|
||||||
|
if not l2_period >= 1:
|
||||||
|
raise ValueError("Invalid l2_period value: {}".format(l2_period))
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||||
|
grad_norm_buffer_size=grad_norm_buffer_size,
|
||||||
|
limit_grad_factor=limit_grad_factor,
|
||||||
|
l2=l2, l2_period=l2_period,
|
||||||
|
min_target_rms=min_target_rms)
|
||||||
|
super(Madam, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
grad_norm_buffer_size = group['grad_norm_buffer_size']
|
||||||
|
limit_grad_factor = group['limit_grad_factor']
|
||||||
|
min_target_rms = group['min_target_rms']
|
||||||
|
|
||||||
|
# The next 5 lists are part of the original Adam optimizer
|
||||||
|
params_with_grad = []
|
||||||
|
grads = []
|
||||||
|
exp_avgs = []
|
||||||
|
exp_avg_sqs = []
|
||||||
|
state_steps = []
|
||||||
|
|
||||||
|
# The next 3 lists are not part of the original Adam optimizer.
|
||||||
|
target_rms_values = [] # relates to weight decay. Target root-mean-square
|
||||||
|
# values of the elements of each parameter
|
||||||
|
# we are optimizing
|
||||||
|
prev_norm_stats = [] # contains Tensor with 2 elements each, the sum
|
||||||
|
# of the [sum_squared, count] of
|
||||||
|
# this parameter on previous minibatches (up to
|
||||||
|
# grad_norm_buffer_size minibatches)
|
||||||
|
cur_grad_norms = [] # and `cur_grad_norms` contains the squared l2
|
||||||
|
# norm norm of this step's gradient for this
|
||||||
|
# parameter, as a Tensor.
|
||||||
|
|
||||||
|
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is not None:
|
||||||
|
params_with_grad.append(p)
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||||
|
grads.append(p.grad)
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
# Lazy state initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
|
|
||||||
|
# The things below are not part of original Adam, they are the Madam extension..
|
||||||
|
state['target_rms'] = _get_target_rms(p, min_target_rms)
|
||||||
|
# grad_norm_buf is a rotating buffer containing (grad_norm**2, count), where
|
||||||
|
# count is 1 for grad_norms that are set and 0 for those that are not set because
|
||||||
|
# we're near step 0 or because they were infinite.
|
||||||
|
state['grad_norm_buf'] = torch.zeros(grad_norm_buffer_size, 2, device=p.device)
|
||||||
|
|
||||||
|
exp_avgs.append(state['exp_avg'])
|
||||||
|
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||||
|
|
||||||
|
target_rms_values.append(state['target_rms'])
|
||||||
|
|
||||||
|
cur_step = state['step']
|
||||||
|
if limit_grad_factor != float('inf'):
|
||||||
|
grad_norm_buf = state['grad_norm_buf']
|
||||||
|
cur_grad_norm = (p.grad ** 2).sum() # actually squared nom
|
||||||
|
prev_mean_norm = grad_norm_buf.sum(0) # prev_mean_norm is a Tensor [ tot_norm_squared, count ]
|
||||||
|
grad_norm_buf[cur_step % grad_norm_buffer_size][0] = cur_grad_norm
|
||||||
|
grad_norm_buf[cur_step % grad_norm_buffer_size][1].fill_(1.0)
|
||||||
|
prev_norm_stats.append(prev_mean_norm)
|
||||||
|
cur_grad_norms.append(cur_grad_norm)
|
||||||
|
|
||||||
|
# update the steps for each param group update
|
||||||
|
cur_step += 1
|
||||||
|
state['step'] = cur_step
|
||||||
|
# record the step after step update
|
||||||
|
state_steps.append(cur_step)
|
||||||
|
|
||||||
|
if limit_grad_factor != float('inf'):
|
||||||
|
self._apply_grad_norm_clipping(group['params'],
|
||||||
|
prev_norm_stats, cur_grad_norms, grads,
|
||||||
|
limit_grad_factor, grad_norm_buffer_size)
|
||||||
|
|
||||||
|
_madam(params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
target_rms_values,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=group['lr'],
|
||||||
|
eps=group['eps'],
|
||||||
|
l2=group['l2'],
|
||||||
|
l2_period=group['l2_period'])
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_grad_norm_clipping(self,
|
||||||
|
params_list,
|
||||||
|
prev_norm_stats: List[Tensor],
|
||||||
|
cur_grad_norms: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
limit_grad_factor: float,
|
||||||
|
grad_norm_buffer_size: int) -> None:
|
||||||
|
"""
|
||||||
|
This function applies gradient norm clipping for this parameter group if this
|
||||||
|
minibatch has substantially larger gradients in this param group than
|
||||||
|
recent minibatches. The idea is to catch cases like where an LSTM
|
||||||
|
happens to blow up in the backward pass, or some code bug causes very
|
||||||
|
large or infinite gradients on a particular minibatch; so we scale
|
||||||
|
down any very large gradients and zero infinite ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params_list: some kind of iterable or list of params in this group
|
||||||
|
prev_norm_stats: a list which, for each parameter in this group
|
||||||
|
with a grad, contains a Tensor with 2 elements each, containing
|
||||||
|
# the [sum, count] of up to `grad_norm_buffer_size`
|
||||||
|
# norms of this parameter on previous minibatches;
|
||||||
|
cur_grad_norms: a list of Tensor containing, for each parameter in this group,
|
||||||
|
the norm of this step's gradient for this parameter.
|
||||||
|
grads: List of gradients with the same order as prev_norm_stats and
|
||||||
|
cur_grad_norms
|
||||||
|
limit_grad_factor: a float >1.0 (e.g. 4.0) that dictates
|
||||||
|
how-much-larger-than-average gradients we allow before clipping.
|
||||||
|
grad_norm_buffer_size: an int that determines the rolling buffer size over which
|
||||||
|
we store gradient norms
|
||||||
|
"""
|
||||||
|
num_params = len(prev_norm_stats)
|
||||||
|
assert len(grads) == num_params
|
||||||
|
|
||||||
|
all_prev_norm_stats, all_cur_grad_norms = _to_device('cpu',
|
||||||
|
torch.stack(prev_norm_stats),
|
||||||
|
torch.stack(cur_grad_norms))
|
||||||
|
assert all_prev_norm_stats.shape == (num_params, 2)
|
||||||
|
assert all_cur_grad_norms.shape == (num_params,)
|
||||||
|
|
||||||
|
# divide totals by counts (i.e. counts of iterations were we stored
|
||||||
|
# a finite grad)
|
||||||
|
all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1]
|
||||||
|
# prev_norm and cur_norm are floats, they are actually squared norms.
|
||||||
|
prev_norm = all_prev_grad_norms.sum().item()
|
||||||
|
cur_norm = all_cur_grad_norms.sum().item()
|
||||||
|
|
||||||
|
if prev_norm - prev_norm != 0.0:
|
||||||
|
# There were zero counts; fix this by using the current grad norm
|
||||||
|
# for affected parameters, and recompute all_prev_grad_norms and
|
||||||
|
# prev_norm.
|
||||||
|
for i in range(num_params):
|
||||||
|
if all_prev_norm_stats[i][1] == 0.0:
|
||||||
|
# if count is 0 and cur norm is finite, use cur norm as our estimate
|
||||||
|
# of previous norms. This would only be useful if some but not
|
||||||
|
# all params were in this situation of having no previous estimates.
|
||||||
|
cur = all_cur_grad_norms[i]
|
||||||
|
if cur - cur == 0.0: # finite..
|
||||||
|
all_prev_norm_stats[i][0] = cur
|
||||||
|
all_prev_norm_stats[i][1] = 1.0
|
||||||
|
else:
|
||||||
|
# 0.0 is a default; likely won't matter, as if we
|
||||||
|
# get infinite `cur`, we'll abandon this minibatch.
|
||||||
|
all_prev_norm_stats[i][0] = 0.0
|
||||||
|
all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1]
|
||||||
|
prev_norm = all_prev_grad_norms.sum().item()
|
||||||
|
|
||||||
|
# Deal with infinite gradients.
|
||||||
|
if cur_norm - cur_norm != 0: # cur_norm is infinite or NaN
|
||||||
|
global inf_grad_count
|
||||||
|
logging.warning(f'Infinite gradient-norm detected (cur/prev: {cur_norm}/{prev_norm}): will '
|
||||||
|
f'zero grad ({inf_grad_count}/{inf_grad_max_count} times until dying)')
|
||||||
|
inf_grad_count += 1
|
||||||
|
if inf_grad_count >= inf_grad_max_count:
|
||||||
|
assert 0, "Reached max count of infinite gradient-norm stats"
|
||||||
|
# Zero all gradients in this group
|
||||||
|
for g in grads:
|
||||||
|
g[:] = 0.
|
||||||
|
# .. and zero the stored gradient norms in grad_norm_buf (so
|
||||||
|
# that infinities don't ruin our stats of previous batches)
|
||||||
|
for p in params_list:
|
||||||
|
if p.grad is not None:
|
||||||
|
state = self.state[p]
|
||||||
|
grad_norm_buf = state['grad_norm_buf']
|
||||||
|
# cur_step is the location where we would have written the grad_norm.
|
||||||
|
# We didn't check if it was infinity before, because we didn't want to
|
||||||
|
# incur lots of GPU->CPU transfers.
|
||||||
|
cur_step = state['step'] - 1
|
||||||
|
# Remove this 'bad' step from the buffer.
|
||||||
|
grad_norm_buf[cur_step % grad_norm_buffer_size][:] = 0.0
|
||||||
|
else:
|
||||||
|
# cur_norm is finite. Check whether we have to clip this iteration's grad.
|
||||||
|
# we always remove infinities/NaNs from the buffer, so prev_norm should not
|
||||||
|
# be infinite or NaN.
|
||||||
|
assert prev_norm - prev_norm == 0.0
|
||||||
|
# cur_norm and prev_norm are actually squared norms, so we need to
|
||||||
|
# square limit_grad_factor..
|
||||||
|
limit_grad_factor2 = limit_grad_factor ** 2
|
||||||
|
if cur_norm > prev_norm * limit_grad_factor2:
|
||||||
|
grad_factor2 = (prev_norm * limit_grad_factor2) / cur_norm
|
||||||
|
grad_factor = grad_factor2 ** 0.5
|
||||||
|
cur_norm_f, prev_norm_f, grad_factor_f = ('%.2g' % cur_norm, '%.2g' % prev_norm,
|
||||||
|
'%.2g' % grad_factor)
|
||||||
|
logging.warning(f'Gradient norm exceeds average of last {grad_norm_buffer_size} '
|
||||||
|
f'gradients times {limit_grad_factor}: cur/prev {cur_norm_f}/{prev_norm_f}: '
|
||||||
|
f'scaling it by {grad_factor_f}.')
|
||||||
|
for g in grads:
|
||||||
|
g[:] *= grad_factor
|
||||||
|
# .. and scale down the stored gradient norms in grad_norm_buf, to
|
||||||
|
# avoid the bound getting too loose too quickly.
|
||||||
|
for p in params_list:
|
||||||
|
if p.grad is not None:
|
||||||
|
state = self.state[p]
|
||||||
|
grad_norm_buf = state['grad_norm_buf']
|
||||||
|
cur_step = state['step'] - 1
|
||||||
|
# the buffer contains squared norms, so multiply by grad_factor2
|
||||||
|
grad_norm_buf[cur_step % grad_norm_buffer_size][0] *= grad_factor2
|
||||||
|
|
||||||
|
|
||||||
|
def _to_device(device, *args):
|
||||||
|
"""
|
||||||
|
Transfers a tuple of Tensors from one device to another, using a single transfer. Must have
|
||||||
|
same dtype but may have different shapes.
|
||||||
|
E.g.
|
||||||
|
(cpu_tensor_a, cpu_tensor_b) = _to_device('cpu', gpu_tensor_a, gpu_tensor_b)
|
||||||
|
"""
|
||||||
|
if device == args[0].device:
|
||||||
|
return args
|
||||||
|
else:
|
||||||
|
arg0 = args[0]
|
||||||
|
combined_src = torch.cat([ x.reshape(-1) for x in args ])
|
||||||
|
combined_dest = combined_src.to(device)
|
||||||
|
dests = []
|
||||||
|
offset = 0
|
||||||
|
for src in args:
|
||||||
|
numels = src.numel()
|
||||||
|
dests.append(combined_dest[offset:offset+numels].reshape(src.shape))
|
||||||
|
offset += numels
|
||||||
|
return tuple(dests)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _get_target_rms(x: Tensor, min_target_rms: float) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns Tensor with one element, representing a target root-mean-square
|
||||||
|
value of elements of x, that we consider "reasonable", and will use a
|
||||||
|
as a "target rms" in our modified weight-decay formula. It returns
|
||||||
|
the maximum of the current RMS of the values of x, and `min_target_rms`,
|
||||||
|
as a Tensor on the same device as x.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# `norm` is the 2-norm of x currently (and this function should be
|
||||||
|
# called right after parameter initialization)
|
||||||
|
rms = ((x ** 2).sum() / x.numel()).sqrt()
|
||||||
|
largest_dim = max(list(x.shape))
|
||||||
|
numel = x.numel()
|
||||||
|
if min_target_rms > 0.0:
|
||||||
|
rms = rms.clamp(min=min_target_rms)
|
||||||
|
if x.ndim > 1 and __name__ == '__main__': # will only be used for x.ndim > 1.
|
||||||
|
print("Target rms = ", rms) # Print this in testing only.
|
||||||
|
return rms
|
||||||
|
|
||||||
|
|
||||||
|
def _madam(params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[int],
|
||||||
|
target_rms_values: List[Tensor],
|
||||||
|
*,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: float,
|
||||||
|
eps: float,
|
||||||
|
l2: bool,
|
||||||
|
l2_period: int):
|
||||||
|
r"""This is a modification of adam() from torch's optim/_functional.py.
|
||||||
|
|
||||||
|
It has been modified to:
|
||||||
|
(i) remove the amsgrad option; this shouldn't be as necessary due to
|
||||||
|
the adaptive gradient norm clipping we have added
|
||||||
|
(ii) add our special formula for l2 regularization. This doesn't have
|
||||||
|
any tunable parameters, other than the target standard deviation
|
||||||
|
of the elements of the tensor (which is passed in as target_rms).
|
||||||
|
Args:
|
||||||
|
params: list of Tensor, containing the parameters to be optimized
|
||||||
|
grads: list of Tensor, containing the gradients corresponding to
|
||||||
|
each of the params (grads[i] should correspond to params[i].grad,
|
||||||
|
although it may have undergone gradient clipping).
|
||||||
|
exp_avgs: list of Tensor, containing tensors with the same dimensions
|
||||||
|
as params and grads, that contain the moving-averages of
|
||||||
|
`grads`.
|
||||||
|
exp_avg_sqs: list of Tensor, containing tensors with the same dimensions
|
||||||
|
as params and grads, that contain the moving-averages of
|
||||||
|
`grads ** 2`.
|
||||||
|
state_steps: list of int, containing the step for each parameter (step >= 1)
|
||||||
|
target_rms_values: list of Tensor with one element each, containing the
|
||||||
|
target root-mean-square values of each parameter tensor in `params`
|
||||||
|
l2: a bool, where if true we will activate the l2 regularization
|
||||||
|
formula.
|
||||||
|
l2_period: an integer that determines how often (i.e. every how many
|
||||||
|
minibatches) we apply the l2 update. We include a scaling factor
|
||||||
|
so that as far as possible the result will not be too sensitive
|
||||||
|
to the value of this.
|
||||||
|
|
||||||
|
beta1: decay factor for gradients, e.g. 0.9
|
||||||
|
beta2: decay factor for gradients squared, e.g. 0.999
|
||||||
|
lr: learning rate, e.g. 0.0001
|
||||||
|
eps: a small constant used to prevent division by zero, e.g. 1.0e-8
|
||||||
|
|
||||||
|
See :class:`~torch.optim.Adam` for details.
|
||||||
|
"""
|
||||||
|
assert len(params) == len(grads) == len(state_steps) == len(exp_avgs) == len(exp_avg_sqs) == len(target_rms_values)
|
||||||
|
|
||||||
|
for i, param in enumerate(params):
|
||||||
|
|
||||||
|
grad = grads[i]
|
||||||
|
|
||||||
|
exp_avg = exp_avgs[i]
|
||||||
|
exp_avg_sq = exp_avg_sqs[i]
|
||||||
|
step = state_steps[i]
|
||||||
|
target_rms = target_rms_values[i]
|
||||||
|
|
||||||
|
bias_correction1 = 1 - beta1 ** step
|
||||||
|
bias_correction2 = 1 - beta2 ** step
|
||||||
|
|
||||||
|
do_l2 = param.ndim > 1 and l2 and step % l2_period == 0
|
||||||
|
|
||||||
|
if do_l2:
|
||||||
|
# This represents just the "noise term" of the gradient, i.e. the grad minus the
|
||||||
|
# running mean. We'll later divide by denom.
|
||||||
|
cur_grad_noise = (grad - exp_avg)
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||||
|
|
||||||
|
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
|
||||||
|
|
||||||
|
step_size = lr / bias_correction1
|
||||||
|
|
||||||
|
if not do_l2:
|
||||||
|
param.addcdiv_(exp_avg, denom, value=-step_size)
|
||||||
|
else:
|
||||||
|
# We can treat "pseudo_grad" as if it were a gradient (even though it's
|
||||||
|
# actually a gradient times a per-element learning rate). The analysis
|
||||||
|
# that we used to figure out what the l2 should be did not use the fact
|
||||||
|
# that the gradients were actually gradients, it simply analyzed it as a
|
||||||
|
# quantity that can be treated as close to zero-mean and with a certain
|
||||||
|
# structure of variance, and added to the param with the formula:
|
||||||
|
#
|
||||||
|
# param -= step_size * grad
|
||||||
|
#
|
||||||
|
# The original analysis assumed the gradients were independent from frame
|
||||||
|
# to frame; in fact these are not, but the important difference can be captured
|
||||||
|
# in a scalar `grad_scale` that expresses the scale of pseudo_grad relative
|
||||||
|
# to the independent gradients that we are effectively adding on each frame
|
||||||
|
# (but with a delay).
|
||||||
|
|
||||||
|
pseudo_grad = exp_avg / denom
|
||||||
|
cur_pseudo_grad = cur_grad_noise / denom
|
||||||
|
|
||||||
|
# grad_scale expresses the expected size of cur_pseudo_grad relative to the
|
||||||
|
# original grads if we had not done the moving-average; it is the sqrt of
|
||||||
|
# the sum of the squares of coefficients of previous gradients:
|
||||||
|
# c_n = (1-beta1) beta1^n, for
|
||||||
|
# n = 0, 1, ..
|
||||||
|
# .. plus one which is the sumsq of the coefficient of 'grad' itself in
|
||||||
|
# (grad - exp_avg).
|
||||||
|
# It is relevant that the sum of the coefficients (i.e. not squared) is 1;
|
||||||
|
# if this were not so we'd have to incorporate that into the formula for l2.
|
||||||
|
grad_scale = (((1 - beta1)**2) / (1 - beta1**2) + 1) ** 0.5
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
l2_grad = _compute_l2_grad(param, cur_pseudo_grad, target_rms,
|
||||||
|
rho=step_size, grad_scale=grad_scale,
|
||||||
|
period_scale=l2_period,
|
||||||
|
eps=eps, safe=True)
|
||||||
|
|
||||||
|
# TODO: could alternate computing l2 on only, say, odd frames, and scale it
|
||||||
|
# up by 2, to save time.
|
||||||
|
param.add_(pseudo_grad + l2_grad, alpha=-step_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _view_as_matrix(x: Tensor, dim: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a Tensor of shape (n, x.shape[dim]), where n is the product
|
||||||
|
of the sizes of the other dimensions of x. This may involve a copy,
|
||||||
|
if x cannot be reshaped in this way.
|
||||||
|
"""
|
||||||
|
ndim = x.ndim
|
||||||
|
assert ndim > 1 and dim >= 0 and dim < ndim
|
||||||
|
# Move the dim to the last position in x..
|
||||||
|
if dim != ndim - 1:
|
||||||
|
x = x.transpose(dim, ndim - 1)
|
||||||
|
return x.reshape(-1, x.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def _outer_product(x: Tensor, dim: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a Tensor of shape (x.shape[dim], x.shape[dim]) formed by
|
||||||
|
summing the outer products of all the vectors in x of size
|
||||||
|
`x.shape[dim]`, that we get by indexing x with all tuples of dimensions
|
||||||
|
on other axes. E.g. if x is a matrix and dim == 0, this would
|
||||||
|
be torch.matmul(x, x.transpose(0, 1)).
|
||||||
|
|
||||||
|
Note: x must have at least 2 dimensions, x.ndim >= 2.
|
||||||
|
"""
|
||||||
|
x = _view_as_matrix(x, dim)
|
||||||
|
return torch.matmul(x.transpose(0, 1), x)
|
||||||
|
|
||||||
|
def _multiply_on_dim(x: Tensor, m: Tensor, dim: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Multiplies x by the matrix m which must be of shape:
|
||||||
|
(x.shape[dim], n)), with `dim` as the dimension/axis on
|
||||||
|
x to be multiplied.
|
||||||
|
|
||||||
|
Caution: result may not have the same layout/strides as x,
|
||||||
|
although it will have the same shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor to be multiplied; must have ndim >= 2
|
||||||
|
m: Symmetric matrix to multiply x by; must have
|
||||||
|
m.shape == (x.shape[dim], x.shape[dim])
|
||||||
|
dim: Dimension of x to multiply on, with 0 <= dim < x.ndim
|
||||||
|
Return:
|
||||||
|
The matrix product, of the same shape as
|
||||||
|
x, except with the size on dimension `dim` being n.
|
||||||
|
"""
|
||||||
|
ndim = x.ndim
|
||||||
|
if dim != ndim - 1:
|
||||||
|
x = x.transpose(dim, ndim - 1)
|
||||||
|
ans = torch.matmul(x, m)
|
||||||
|
if dim != ndim - 1:
|
||||||
|
# Swap the dimensions back to what they were originally.
|
||||||
|
ans = ans.transpose(dim, ndim - 1)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def _multiply_product_combined(l2: Tensor, grad: Tensor, dim: int,
|
||||||
|
need_grad_sumsq: bool):
|
||||||
|
"""
|
||||||
|
This function is an optimized version of the following code:
|
||||||
|
outer_prod = _outer_product(grad, dim)
|
||||||
|
l2 = _multiply_on_dim(l2, outer_prod, dim)
|
||||||
|
if dim == 0: # could choose any dim for this
|
||||||
|
grad_sumsq = torch.trace(outer_prod)
|
||||||
|
Args:
|
||||||
|
l2: The l2 matrix which starts out as the parameter tensor x, must have >= 2 diims
|
||||||
|
grad: The gradient tensor (or a gradient-like quantity); must
|
||||||
|
have same shape as l2.
|
||||||
|
dim: The dimension of l2 and grad that we want this to
|
||||||
|
act on, with 0 <= dim < l2.ndim. We multiply l2, on
|
||||||
|
this dim, by a symmetric quantity of shape
|
||||||
|
(l2.shape[dim], l2.shape[dim]), that is formed
|
||||||
|
by a product and sum on grad (this is a matrix
|
||||||
|
product, if there are 2 axes).
|
||||||
|
Returns:
|
||||||
|
Returns (l2, grad_sumsq), where l2 is the result of
|
||||||
|
multiplying l2 by the product mentioned above, and
|
||||||
|
grad_sumsq is either None, or a Tensor representing
|
||||||
|
the sum-of-squares of `grad`; for at least one
|
||||||
|
dim with 0 <= dim < l2.ndim, we guarantee to
|
||||||
|
return such a Tensor.
|
||||||
|
"""
|
||||||
|
grad = _view_as_matrix(grad, dim)
|
||||||
|
if grad.shape[1] <= grad.shape[0]:
|
||||||
|
# Minimize the size of the intermediate product, which will probably well reflect
|
||||||
|
# the compute time since memory access can be limiting on CUDA.a
|
||||||
|
grad_product = torch.matmul(grad.transpose(0, 1), grad)
|
||||||
|
l2 = _multiply_on_dim(l2, grad_product, dim)
|
||||||
|
if need_grad_sumsq:
|
||||||
|
grad_sumsq = torch.trace(grad_product)
|
||||||
|
else:
|
||||||
|
grad_sumsq = None
|
||||||
|
return (l2, grad_sumsq)
|
||||||
|
else:
|
||||||
|
l2 = _multiply_on_dim(l2, grad.transpose(0, 1), dim)
|
||||||
|
l2 = _multiply_on_dim(l2, grad, dim)
|
||||||
|
# This branch does not compute grad_sumsq, but we're bound to
|
||||||
|
# take the other branch on at least one occasion.
|
||||||
|
return (l2, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_l2_grad(x: Tensor, grad: Tensor, target_stddev: float, rho: float,
|
||||||
|
grad_scale: float = 1.0, period_scale: int = 1,
|
||||||
|
eps: float = 1.0e-08,
|
||||||
|
safe: bool = True) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns the l2 gradient of x, which will be added to 'grad'.
|
||||||
|
This is a more principled replacement for the typical l2 regularization
|
||||||
|
formula where we do:
|
||||||
|
grad += weight_decay * x.
|
||||||
|
(Note: this must only be called if x.ndim >= 2).
|
||||||
|
|
||||||
|
For x with 2 axes, we instead do this:
|
||||||
|
|
||||||
|
grad += (rho / (2*target_stddev**2)) * (grad grad^T) x (grad^T grad) / trace(grad^T grad),
|
||||||
|
|
||||||
|
where the implicit multiplication above refers to matrix multiplication; note, x means
|
||||||
|
the variable x. We'll have to write the justification of this, which is a little
|
||||||
|
complicated, separately; it has to do with using exactly the amount of l2 in each
|
||||||
|
subspace of each dimension of x, to to cancel out the gradient noise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: parameter to be updated. MUST HAVE x.ndim >= 2.
|
||||||
|
grad: Gradient for x on this iteration (or at least, something that
|
||||||
|
is treated like a gradient in the update formula)
|
||||||
|
target_stddev: The target standard deviation (uncentered), of elements of x.
|
||||||
|
This is our estimate of what standard deviation these elements would
|
||||||
|
have in a well-trained model; it is set by some kind of heuristic.
|
||||||
|
rho: The learning rate we are going to use, as in: x -= (grad + l2) * rho.
|
||||||
|
grad_scale: A scale whereby the caller asserts that `grad` is some
|
||||||
|
quantity that is distributed like the real
|
||||||
|
gradient times `grad_scale` (this is useful when the provided `grad`
|
||||||
|
is really a moving average gradient). Because the l2 term's magnitude
|
||||||
|
is proportional to the gradient squared, we need to divide it by the
|
||||||
|
square of grad_scale, so this function uses 1/grad_scale^2 as a scaling
|
||||||
|
factor.
|
||||||
|
period_scale: An integer scale that we use to compensate for the fact that this
|
||||||
|
weight decay is only applied periodically, once every
|
||||||
|
`period_scale` minibatches. Accordingly, we make the l2 term
|
||||||
|
that many times larger.
|
||||||
|
eps: A small constant used to avoid division by zero
|
||||||
|
safe: If true, use a safe version of the formula that checks for
|
||||||
|
'overshoot' of l2 regularization and fixes the issue (might
|
||||||
|
be an issue for models that are getting unstable or have high
|
||||||
|
learning rate)
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns l2 pseudo-gradient (term to be added to `grad`).
|
||||||
|
"""
|
||||||
|
assert x.shape == grad.shape
|
||||||
|
assert x.ndim >= 2
|
||||||
|
|
||||||
|
l2 = x
|
||||||
|
grad_sumsq = None
|
||||||
|
num_ignored_dims = 0 # for an optimization for when size=1 on some dim.
|
||||||
|
for dim in range(x.ndim):
|
||||||
|
# The code below is an optimization of the following few lines,
|
||||||
|
# which were perhaps easier to understand:
|
||||||
|
# outer_prod = _outer_product(grad, dim)
|
||||||
|
# l2 = _multiply_on_dim(l2, outer_prod, dim)
|
||||||
|
# if dim == 0: # could choose any dim for this
|
||||||
|
# grad_sumsq = torch.trace(outer_prod)
|
||||||
|
if x.shape[dim] <= 1:
|
||||||
|
num_ignored_dims += 1
|
||||||
|
continue
|
||||||
|
(l2, maybe_grad_sumsq) = _multiply_product_combined(l2, grad, dim,
|
||||||
|
grad_sumsq is None)
|
||||||
|
if maybe_grad_sumsq is not None:
|
||||||
|
grad_sumsq = maybe_grad_sumsq
|
||||||
|
if grad_sumsq is None:
|
||||||
|
# We shouldn't reach here, except if at some point we start calling this
|
||||||
|
# code for tensors with ndim <= 1, or with numel() == 1.
|
||||||
|
grad_sumsq = (grad ** 2).sum()
|
||||||
|
|
||||||
|
# l2 is the amount of l2, we'll subtract this from x, as in:
|
||||||
|
# x -= rho * (grad + l2).
|
||||||
|
|
||||||
|
factor = rho * period_scale / (2.0 * (target_stddev * grad_scale)**2)
|
||||||
|
l2 = l2 * (factor / (grad_sumsq ** (x.ndim - 1 - num_ignored_dims) + eps))
|
||||||
|
|
||||||
|
if safe and rho > 0:
|
||||||
|
#x2_sum = (x ** 2).sum()
|
||||||
|
l2_sum = (l2 ** 2).sum() * (rho * rho)
|
||||||
|
cross_sum = (x * l2).sum() * rho
|
||||||
|
alpha = cross_sum / (l2_sum + eps)
|
||||||
|
# We want to minimize the sum-of-squares of (x - alpha * rho * l2), where alpha
|
||||||
|
# is a constant in [0,1] that we are about to estimate, intended to prevent
|
||||||
|
# instability by scaling down our weight decay formula. Right now (and treating
|
||||||
|
# things as if they were scalars for brevity):
|
||||||
|
# x2_sum = x * x
|
||||||
|
# l2_sum = rho * rho * l2 * l2
|
||||||
|
# cross_sum = x * rho * l2
|
||||||
|
# We want to minimize the sum-sq of (x - alpha * rho * l2),
|
||||||
|
# i.e. we want to choose alpha to minimize:
|
||||||
|
# x2_sum - 2 * alpha * cross_sum + alpha^2 * l2_sum
|
||||||
|
# d/dalpha of this, is:
|
||||||
|
# -2*cross_sum + 2 * alpha * l2_sum
|
||||||
|
# and setting this to zero and solving for alpha, we have:
|
||||||
|
# alpha = cross_sum / l2_sum.
|
||||||
|
# If it turns out that alpha >= 1, then we just use alpha=1
|
||||||
|
# (the original formula), as there is no problem with
|
||||||
|
# instability/overshoot.
|
||||||
|
l2.mul_(alpha.clamp(max=1.0))
|
||||||
|
if random.random() < 0.001 and alpha < 1.0:
|
||||||
|
logging.info(f'madam optimizer: alpha={alpha}, shape={tuple(x.shape)}')
|
||||||
|
return l2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Moam(object):
|
||||||
|
"""
|
||||||
|
Implements Moam optimizer. This is a modified version of the Noam optimizer
|
||||||
|
which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
|
||||||
|
but changed to use Madam (see above) instead of Adam as the base optimizer.
|
||||||
|
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
|
||||||
|
|
||||||
|
Caution: you probably want to set 'factor' to a smaller value than you would typically
|
||||||
|
use for a corresponding Noam optimizer, because Moam does a kind of l2 regularization which
|
||||||
|
keeps the parameters fairly small, so the relative changes in model parameters
|
||||||
|
will be larger than Noam, for any given learning rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||||
|
model_size: attention dimension of the transformer model
|
||||||
|
factor: learning rate factor, that multiplies the output of the
|
||||||
|
formula based on model size
|
||||||
|
warm_step: number of warmup steps before the learning rate starts to decrease
|
||||||
|
(it increases until this point).
|
||||||
|
min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
|
||||||
|
on the "target root-mean-square value" that is used when the initialization
|
||||||
|
of a tensor is zero or below this value. It may be worth optimizing.
|
||||||
|
Don't worry about tensors with fewer than 2 dimensions when setting this,
|
||||||
|
these are not subject to our l2 formula.
|
||||||
|
limit_grad_factor: you can set this to a finite value, e.g. 2.0, to activate
|
||||||
|
a mechanism that limits the norms of larger-than-usual gradients.
|
||||||
|
This seems to cause a slowdown, likely due to GPU->CPU transfers.
|
||||||
|
l2_period: mechanism to improve the optimization speed, by only applying the l2
|
||||||
|
regularization (which is a complicated formula) every this-many
|
||||||
|
minibatches. E.g. can set it to 2 or 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, model_size: int = 256,
|
||||||
|
factor: float = 2.0, warm_step: int = 25000,
|
||||||
|
min_target_rms: float = 0.05,
|
||||||
|
limit_grad_factor: float = float('inf'),
|
||||||
|
l2_period: int = 1) -> None:
|
||||||
|
"""Construct an Noam object."""
|
||||||
|
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
|
||||||
|
min_target_rms=min_target_rms,
|
||||||
|
limit_grad_factor=limit_grad_factor,
|
||||||
|
l2_period=l2_period)
|
||||||
|
self._step = 0
|
||||||
|
self.warmup = warm_step
|
||||||
|
self.factor = factor
|
||||||
|
self.model_size = model_size
|
||||||
|
self._rate = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
"""Return param_groups."""
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
"""Update parameters and rate."""
|
||||||
|
self._step += 1
|
||||||
|
rate = self.rate()
|
||||||
|
for p in self.optimizer.param_groups:
|
||||||
|
p["lr"] = rate
|
||||||
|
self._rate = rate
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
def rate(self, step=None):
|
||||||
|
"""Implement `lrate` above."""
|
||||||
|
if step is None:
|
||||||
|
step = self._step
|
||||||
|
return (
|
||||||
|
self.factor
|
||||||
|
* self.model_size ** (-0.5)
|
||||||
|
* min(step ** (-0.5), step * self.warmup ** (-1.5))
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
"""Reset gradient."""
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Return state_dict."""
|
||||||
|
return {
|
||||||
|
"_step": self._step,
|
||||||
|
"warmup": self.warmup,
|
||||||
|
"factor": self.factor,
|
||||||
|
"model_size": self.model_size,
|
||||||
|
"_rate": self._rate,
|
||||||
|
"optimizer": self.optimizer.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""Load state_dict."""
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key == "optimizer":
|
||||||
|
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||||
|
else:
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(torch.nn.Module):
|
||||||
|
"""Class for testing the Madam optimizer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(TestModel, self).__init__()
|
||||||
|
self.first_layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(100, 200),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(200, 300),
|
||||||
|
torch.nn.ReLU())
|
||||||
|
self.conv1 = torch.nn.Conv1d(in_channels=300, out_channels=200,
|
||||||
|
kernel_size=1)
|
||||||
|
self.relu = torch.nn.ReLU()
|
||||||
|
self.conv2 = torch.nn.Conv1d(in_channels=200, out_channels=250,
|
||||||
|
kernel_size=3)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# from (B, T, 100) to (B, T, 200)
|
||||||
|
x = self.first_layers(x)
|
||||||
|
# B, T, C -> B, C, T
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
x = self.conv2(self.relu(self.conv1(x)))
|
||||||
|
# B, C, T -> B, T, C
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_madam():
|
||||||
|
print("Testing Madam optimizer")
|
||||||
|
global inf_grad_max_count
|
||||||
|
inf_grad_max_count = 200
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
devices_and_l2 = [(torch.device('cuda'), True),
|
||||||
|
(torch.device('cuda'), False)]
|
||||||
|
#(torch.device('cpu'), True),
|
||||||
|
#(torch.device('cpu'), False)]
|
||||||
|
else:
|
||||||
|
devices_and_l2 = [(torch.device('cpu'), True),
|
||||||
|
(torch.device('cpu'), False)]
|
||||||
|
|
||||||
|
|
||||||
|
for (device, l2) in devices_and_l2:
|
||||||
|
model = TestModel().to(device)
|
||||||
|
# min_target_rms=0.01 is for testing, so the target equals the initial RMS
|
||||||
|
# and we can more easily tell whether our update has the desired effect.
|
||||||
|
# I also tested this with betas=(0.1, 0.98), to check that the effect of
|
||||||
|
# `grad_scale` was correct (it only makes much difference for small beta).
|
||||||
|
optimizer = Madam(model.parameters(), lr=0.0005, betas=(0.9, 0.98),
|
||||||
|
l2=l2, min_target_rms=0.01, l2_period=1)
|
||||||
|
#optimizer = torch.optim.Adam(model.parameters())
|
||||||
|
|
||||||
|
def get_elems_rms(x: Tensor) -> Tensor:
|
||||||
|
return ((x ** 2).sum() / x.numel()).sqrt().item()
|
||||||
|
|
||||||
|
for i in range(1000):
|
||||||
|
if i % 100 == 0:
|
||||||
|
rms_values = (get_elems_rms(model.first_layers[0].weight),
|
||||||
|
get_elems_rms(model.first_layers[2].weight),
|
||||||
|
get_elems_rms(model.conv1.weight),
|
||||||
|
get_elems_rms(model.conv2.weight))
|
||||||
|
print(f"Iter {i}, l2={l2}, device={device}: stddevs = {rms_values} ")
|
||||||
|
B = 4
|
||||||
|
T = 20
|
||||||
|
x = torch.randn(B, T, 100).to(device)
|
||||||
|
y = model(x)
|
||||||
|
yderiv = torch.randn_like(y)
|
||||||
|
if i % 190 <= 3 and i > 0:
|
||||||
|
yderiv *= 100.0
|
||||||
|
if i % 550 == 0 and i > 0:
|
||||||
|
yderiv *= float('inf')
|
||||||
|
|
||||||
|
y.backward(gradient=yderiv)
|
||||||
|
optimizer.step()
|
||||||
|
model.zero_grad()
|
||||||
|
print("")
|
||||||
|
|
||||||
|
def test_moam():
|
||||||
|
print("Testing Moam optimizer")
|
||||||
|
model = TestModel()
|
||||||
|
# min_target_rms=0.01 is for testing, so the target equals the initial RMS
|
||||||
|
# and we can more easily tell whether our update has the desired effect.
|
||||||
|
optimizer = Moam(model.parameters(), factor=1.0, warm_step=300,
|
||||||
|
min_target_rms=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def get_elems_rms(x: Tensor) -> Tensor:
|
||||||
|
return ((x ** 2).sum() / x.numel()).sqrt().item()
|
||||||
|
|
||||||
|
for i in range(1000):
|
||||||
|
if i % 100 == 0:
|
||||||
|
rms_values = (get_elems_rms(model.first_layers[0].weight),
|
||||||
|
get_elems_rms(model.first_layers[2].weight),
|
||||||
|
get_elems_rms(model.conv1.weight),
|
||||||
|
get_elems_rms(model.conv2.weight))
|
||||||
|
print(f"Iter {i} (Moam): stddevs = {rms_values} ")
|
||||||
|
B = 4
|
||||||
|
T = 20
|
||||||
|
x = torch.randn(B, T, 100)
|
||||||
|
y = model(x)
|
||||||
|
yderiv = torch.randn_like(y)
|
||||||
|
if i % 190 <= 3 and i > 0:
|
||||||
|
yderiv *= 100.0
|
||||||
|
if i % 550 == 0 and i > 0:
|
||||||
|
yderiv *= float('inf')
|
||||||
|
|
||||||
|
y.backward(gradient=yderiv)
|
||||||
|
optimizer.step()
|
||||||
|
model.zero_grad()
|
||||||
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_device():
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
a_gpu = torch.ones(1,2,3,4, device='cuda')
|
||||||
|
b_gpu = torch.zeros(3,8, device='cuda')
|
||||||
|
(a_cpu, b_cpu) = _to_device('cpu', a_gpu, b_gpu)
|
||||||
|
print("a_cpu,b_cpu = ", a_cpu, b_cpu)
|
||||||
|
(a_gpu2, b_gpu2) = _to_device('cuda', a_cpu, b_cpu)
|
||||||
|
print("a_gpu2,b_gpu2 = ", a_gpu2, b_gpu2)
|
||||||
|
|
||||||
|
# Caution: this testing code is not very automated, it reqires looking at the output to
|
||||||
|
# make sure it looks right. The main thing is that with l2=True, the printed stddevs stay close
|
||||||
|
# to the "Target rms" values, which are printed out; while with l2=False, the stddevs
|
||||||
|
# increase to significantly higher than that.
|
||||||
|
#
|
||||||
|
# The test of the Moam optimizer is mainly to make sure it runs; the scale of the
|
||||||
|
# gradients, and the learning rate, are such that one of the rms's stays quite a bit
|
||||||
|
# above the target value, i.e. (0.047, 0.044, 0.047), vs. targets of
|
||||||
|
# (0.057, 0.04, 0.019), I think this has to do with the alpha<1 stability mechanism being
|
||||||
|
# activated, the l2 does have an effect, as I verified by changing the code to set
|
||||||
|
# l2=False.
|
||||||
|
def main():
|
||||||
|
# Set number of threads to 1, or Torch can do weird things that make it extremely slow.
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
#test_to_device()
|
||||||
|
random.seed(0)
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
test_madam()
|
||||||
|
#test_moam()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
607
egs/librispeech/ASR/conformer_lm/train.py
Executable file
607
egs/librispeech/ASR/conformer_lm/train.py
Executable file
@ -0,0 +1,607 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import dataset # from .
|
||||||
|
import madam # from .
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import MaskedLmConformer
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from madam import Moam
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--world-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of GPUs for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master-port",
|
||||||
|
type=int,
|
||||||
|
default=12354,
|
||||||
|
help="Master port to use for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
"""Return a dict containing training parameters.
|
||||||
|
|
||||||
|
All training related parameters that are not passed from the commandline
|
||||||
|
is saved in the variable `params`.
|
||||||
|
|
||||||
|
Commandline options are merged into `params` after they are parsed, so
|
||||||
|
you can also access them via `params`.
|
||||||
|
|
||||||
|
Explanation of options saved in `params`:
|
||||||
|
|
||||||
|
- exp_dir: It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
|
||||||
|
- lr: It specifies the initial learning rate
|
||||||
|
|
||||||
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
|
in computing features.
|
||||||
|
|
||||||
|
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||||
|
and continue training from that checkpoint.
|
||||||
|
|
||||||
|
- num_epochs: Number of epochs to train.
|
||||||
|
|
||||||
|
- num_valid_batches: Number of batches of validation data to use each
|
||||||
|
time we compute validation loss
|
||||||
|
|
||||||
|
- symbols_per_batch: Number of symbols in each batch (sampler will
|
||||||
|
choose the number of sentences to satisfy this contraint).
|
||||||
|
|
||||||
|
- best_train_loss: Best training loss so far. It is used to select
|
||||||
|
the model that has the lowest training loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_valid_loss: Best validation loss so far. It is used to select
|
||||||
|
the model that has the lowest validation loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_train_epoch: It is the epoch that has the best training loss.
|
||||||
|
|
||||||
|
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||||
|
|
||||||
|
- batch_idx_train: Used to writing statistics to tensorboard. It
|
||||||
|
contains number of batches trained so far across
|
||||||
|
epochs.
|
||||||
|
|
||||||
|
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
||||||
|
|
||||||
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||||
|
|
||||||
|
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||||
|
|
||||||
|
"""
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"exp_dir": Path("conformer_lm/exp_1"),
|
||||||
|
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
||||||
|
"num_tokens": 5000,
|
||||||
|
"blank_sym": 0,
|
||||||
|
"bos_sym": 1,
|
||||||
|
"eos_sym": 1,
|
||||||
|
"start_epoch": 0,
|
||||||
|
"num_epochs": 20,
|
||||||
|
"num_valid_batches": 100,
|
||||||
|
"symbols_per_batch": 10000,
|
||||||
|
"best_train_loss": float("inf"),
|
||||||
|
"best_valid_loss": float("inf"),
|
||||||
|
"best_train_epoch": -1,
|
||||||
|
"best_valid_epoch": -1,
|
||||||
|
"batch_idx_train": 0,
|
||||||
|
"log_interval": 10,
|
||||||
|
"reset_interval": 200,
|
||||||
|
"valid_interval": 3000,
|
||||||
|
"beam_size": 10,
|
||||||
|
"accum_grad": 1,
|
||||||
|
"attention_dim": 512,
|
||||||
|
"nhead": 8,
|
||||||
|
"num_decoder_layers": 6,
|
||||||
|
"lr_factor": 2.0,
|
||||||
|
"warm_step": 20000,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_if_available(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Load checkpoint from file.
|
||||||
|
|
||||||
|
If params.start_epoch is positive, it will load the checkpoint from
|
||||||
|
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||||
|
|
||||||
|
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||||
|
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||||
|
and `best_valid_loss` in `params`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
The return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
optimizer:
|
||||||
|
The optimizer that we are using.
|
||||||
|
scheduler:
|
||||||
|
The learning rate scheduler we are using.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
if params.start_epoch <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
|
saved_params = load_checkpoint(
|
||||||
|
filename,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
"best_train_epoch",
|
||||||
|
"best_valid_epoch",
|
||||||
|
"batch_idx_train",
|
||||||
|
"best_train_loss",
|
||||||
|
"best_valid_loss",
|
||||||
|
]
|
||||||
|
for k in keys:
|
||||||
|
params[k] = saved_params[k]
|
||||||
|
|
||||||
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Save model, optimizer, scheduler and training stats to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
"""
|
||||||
|
if rank != 0:
|
||||||
|
return
|
||||||
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||||
|
save_checkpoint_impl(
|
||||||
|
filename=filename,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
|
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_train_filename)
|
||||||
|
|
||||||
|
if params.best_valid_epoch == params.cur_epoch:
|
||||||
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
model: nn.Module,
|
||||||
|
batch: Tuple,
|
||||||
|
is_training: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compute training or validation loss given the model and its inputs
|
||||||
|
(this corresponds to log-prob of the targets, with weighting
|
||||||
|
of 1.0 for masked subsequences
|
||||||
|
(including padding blanks), and something smaller, e.g. 0.25,
|
||||||
|
for non-masked positions (this is not totally trivial due to
|
||||||
|
a small amount of randomization of symbols).
|
||||||
|
|
||||||
|
This loss is not normalized; you can divide by batch[4].sum()
|
||||||
|
to get a normalized loss (i.e. divide by soft-count).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
Parameters for training. See :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training. It is an instance of MaskedLmConformer in our case.
|
||||||
|
batch:
|
||||||
|
A batch of data, actually a tuple of 5 tensors (on the device), as returned
|
||||||
|
by collate_fn in ./dataset.py.
|
||||||
|
is_training:
|
||||||
|
True for training. False for validation. When it is True, this
|
||||||
|
function enables autograd during computation; when it is False, it
|
||||||
|
disables autograd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns the loss as a scalar tensor.
|
||||||
|
"""
|
||||||
|
(masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights) = batch
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(is_training):
|
||||||
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
|
tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask)
|
||||||
|
loss = (tgt_nll * tgt_weights).sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def compute_validation_loss(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Run the validation process. The validation loss
|
||||||
|
is saved in `params.valid_loss`.
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tot_loss = 0.0
|
||||||
|
tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
# `batch` is actually a tuple.. we'll unpack it later.
|
||||||
|
loss = compute_loss(model, batch, is_training=False)
|
||||||
|
num_frames = batch[4].sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad is False
|
||||||
|
assert ctc_loss.requires_grad is False
|
||||||
|
assert att_loss.requires_grad is False
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = num_frames.cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
s = torch.tensor(
|
||||||
|
[tot_loss, tot_frames],
|
||||||
|
device=loss.device,
|
||||||
|
)
|
||||||
|
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||||
|
(tot_loss, tot_frames) = s.cpu().tolist()
|
||||||
|
|
||||||
|
params.valid_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if params.valid_loss < params.best_valid_loss:
|
||||||
|
params.best_valid_epoch = params.cur_epoch
|
||||||
|
params.best_valid_loss = params.valid_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
train_dl: torch.utils.data.DataLoader,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
|
The training loss from the mean of all frames is saved in
|
||||||
|
`params.train_loss`. It runs the validation process every
|
||||||
|
`params.valid_interval` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device:
|
||||||
|
The device to use for training (model must be on this device)
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training.
|
||||||
|
optimizer:
|
||||||
|
The optimizer we are using.
|
||||||
|
train_dl:
|
||||||
|
Dataloader for the training dataset.
|
||||||
|
valid_dl:
|
||||||
|
Dataloader for the validation dataset.
|
||||||
|
tb_writer:
|
||||||
|
Writer to write log messages to tensorboard.
|
||||||
|
world_size:
|
||||||
|
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||||
|
"""
|
||||||
|
model.train() # training mode
|
||||||
|
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
params.tot_loss = 0.0
|
||||||
|
params.tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
loss = compute_loss(
|
||||||
|
model=model,
|
||||||
|
batch=batch,
|
||||||
|
is_training=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward() # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
|
||||||
|
# gradient scale so this should not matter.
|
||||||
|
# clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = batch[4].sum().cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
params.tot_frames += num_frames_cpu
|
||||||
|
params.tot_loss += loss_cpu
|
||||||
|
|
||||||
|
tot_avg_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if batch_idx % params.log_interval == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
|
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
|
||||||
|
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||||
|
f"batch size: {batch_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/current_loss",
|
||||||
|
loss_cpu / params.train_frames,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_loss",
|
||||||
|
tot_avg_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
|
compute_validation_loss(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
model.train()
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, "
|
||||||
|
f"valid loss {params.valid_loss:.4f},"
|
||||||
|
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||||
|
f"best valid epoch: {params.best_valid_epoch}"
|
||||||
|
)
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_loss",
|
||||||
|
params.valid_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.train_loss = params.tot_loss / params.tot_frames
|
||||||
|
|
||||||
|
if params.train_loss < params.best_train_loss:
|
||||||
|
params.best_train_epoch = params.cur_epoch
|
||||||
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
|
def run(rank, world_size, args):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
rank:
|
||||||
|
It is a value between 0 and `world_size-1`, which is
|
||||||
|
passed automatically by `mp.spawn()` in :func:`main`.
|
||||||
|
The node with rank 0 is responsible for saving checkpoint.
|
||||||
|
world_size:
|
||||||
|
Number of GPUs for DDP training.
|
||||||
|
args:
|
||||||
|
The return value of get_parser().parse_args()
|
||||||
|
"""
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
fix_random_seed(42)
|
||||||
|
if world_size > 1:
|
||||||
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
|
logging.info("Training started")
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
if args.tensorboard and rank == 0:
|
||||||
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
else:
|
||||||
|
tb_writer = None
|
||||||
|
|
||||||
|
num_tokens = params.num_tokens
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = MaskedLmConformer(
|
||||||
|
num_classes=params.num_tokens,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
if world_size > 1:
|
||||||
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
|
optimizer = Moam(
|
||||||
|
model.parameters(),
|
||||||
|
model_size=params.attention_dim,
|
||||||
|
factor=params.lr_factor,
|
||||||
|
warm_step=params.warm_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
if checkpoints:
|
||||||
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
|
||||||
|
|
||||||
|
collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym,
|
||||||
|
eos_sym=params.eos_sym,
|
||||||
|
blank_sym=params.blank_sym,
|
||||||
|
mask_proportion=0.15,
|
||||||
|
padding_proportion=0.15,
|
||||||
|
randomize_proportion=0.05,
|
||||||
|
inv_mask_length=0.25,
|
||||||
|
unmasked_weight=0.25))
|
||||||
|
|
||||||
|
train_sampler = dataset.LmBatchSampler(train,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
test_sampler = dataset.LmBatchSampler(test,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(train,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
valid_dl = torch.utils.data.DataLoader(test,
|
||||||
|
batch_sampler=test_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
cur_lr = optimizer._rate
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
|
params.cur_epoch = epoch
|
||||||
|
|
||||||
|
train_one_epoch(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_dl=train_dl,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
tb_writer=tb_writer,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_checkpoint(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
world_size = args.world_size
|
||||||
|
assert world_size >= 1
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user