Some fixes..

Daniel Povey 2022-04-04 22:40:18 +08:00
parent 72f4a673b1
commit d1f2f93460
2 changed files with 12 additions and 97 deletions

View File

@@ -27,7 +27,7 @@ class Eve(Optimizer):
r"""
Implements Eve algorithm. This is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular specified value (generally 0.1). This is
rms of the parameters approach a particular specified value (we suggest 0.1). This is
for use with networks with 'scaled' versions of modules (see scaling.py), which
will be close to invariant to the absolute scale on the parameter matrix.
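
As a rough illustration of the target-RMS idea described in this docstring (not part of the diff), the quantity being controlled is simply the root-mean-square of each parameter tensor:

    import torch

    p = torch.randn(512, 512) * 0.1   # a 'scaled' module's weight, roughly at the target scale
    rms = (p ** 2).mean().sqrt()
    print(f"parameter rms: {rms:.3f}")  # close to 0.1 for this toy tensor
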
@@ -120,7 +120,7 @@ class Eve(Optimizer):
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
step_size = group['lr'] / bias_correction1
target_rms = group['target_rms']
@@ -141,7 +141,7 @@ class Eve(Optimizer):
# Suppose we are going to shrinkage with a small value epsilon (not the
# same as the eps above!), i.e. param *= (1-epsilon). Then
# if E[param_elem^2] == target_rms^2,
# E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
# E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
# which we can put as:
# delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
# Setting delta_var_from_shrinkage = -delta_var_from_update
@@ -157,98 +157,9 @@ class Eve(Optimizer):
# decay.
# this is the weight-decay amount...
weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
p.mul_(1 - weight_decay)
p.add_(delta, alpha=-step_size)
return loss
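
Putting the hunks above together, here is a minimal self-contained sketch of the Eve-style update they describe (illustrative only, not part of the commit: the real optimizer iterates over parameter groups and treats scalar parameters separately, and `delta = exp_avg / denom` is assumed here as in a plain Adam step):

    import torch

    def eve_like_update(p, grad, exp_avg, exp_avg_sq, step,
                        lr=3e-4, betas=(0.9, 0.98), eps=1e-8, target_rms=0.1):
        beta1, beta2 = betas
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Adam-style first and second moment running averages.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias-corrected denominator, written as in the post-change line of the diff.
        denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(eps)
        step_size = lr / bias_correction1

        delta = exp_avg / denom
        # Shrinkage sized so that, when the parameter rms sits at target_rms, the
        # variance removed by shrinking roughly cancels the variance added by the update.
        weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
        p.mul_(1 - weight_decay)
        p.add_(delta, alpha=-step_size)
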
class Noam(object):
    """
    Implements Noam optimizer.

    Proposed in
    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf

    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa

    Args:
      params:
        iterable of parameters to optimize or dicts defining parameter groups
      model_size:
        attention dimension of the transformer model
      factor:
        learning rate factor
      warm_step:
        warmup steps
    """

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay=0,
    ) -> None:
        """Construct an Noam object."""
        self.optimizer = torch.optim.Adam(
            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
        )
        self._step = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Implement `lrate` above."""
        if step is None:
            step = self._step
        return (
            self.factor
            * self.model_size ** (-0.5)
            * self.warmup ** (-0.5 - -0.333)
            * min(step ** (-0.333), step * self.warmup ** (-1.333))
        )

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)
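
For context on the schedule this commit removes, a quick way to see what `Noam.rate()` produces (illustrative only; assumes the class above and `torch` are in scope): the rate grows roughly linearly up to `warm_step` and then decays like `step ** -0.333`.

    import torch
    import torch.nn as nn

    model = nn.Linear(256, 256)  # stand-in model
    noam = Noam(model.parameters(), model_size=256, factor=10.0, warm_step=25000)
    for s in (1, 1000, 25000, 100000):
        print(s, noam.rate(step=s))  # the peak learning rate is reached around warm_step
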

View File

@@ -48,7 +48,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer, Noam
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -437,7 +437,7 @@ def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
@@ -652,7 +652,7 @@ def train_one_epoch(
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
scheduler.step()
if params.print_diagnostics and batch_idx == 5:
return
@@ -848,7 +848,7 @@ def run(rank, world_size, args):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
cur_lr = scheduler.get_last_lr()[0]
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
@@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
        # warmup = 0.0 is so that the derivs for the pruned loss stay zero
        # (i.e. are not remembered by the decaying-average in adam), because
        # we want to avoid these params being subject to shrinkage in adam.
        loss, _ = compute_loss(
            params=params,
            model=model,
            sp=sp,
            batch=batch,
            is_training=True,
            warmup = 0.0
        )
        loss.backward()
        optimizer.step()
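
Schematic of the warm-up gating the comment above relies on (hypothetical, not copied from `compute_loss`): the pruned part of the loss is scaled by a factor that is zero early in warm-up, so calling `compute_loss(..., warmup = 0.0)` leaves the pruned-loss term with no gradient.

    def combine_losses(simple_loss, pruned_loss, warmup: float):
        # Hypothetical stand-in for the scaling done inside compute_loss / the model:
        # with warmup = 0.0 the pruned loss contributes nothing, so its gradients
        # never enter Adam's running averages during the OOM scan.
        pruned_loss_scale = 0.0 if warmup < 1.0 else 1.0
        return simple_loss + pruned_loss_scale * pruned_loss
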