Some fixes..

Daniel Povey 2022-04-04 22:40:18 +08:00
parent 72f4a673b1
commit d1f2f93460
2 changed files with 12 additions and 97 deletions

View File

@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (generally 0.1). This is
+    rms of the parameters approach a particular specified value (we suggest 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
@@ -120,7 +120,7 @@ class Eve(Optimizer):
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])

                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
@@ -157,98 +157,9 @@ class Eve(Optimizer):
                 # decay.
                 # this is the weight-decay amount...
-                weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
+                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
                 p.mul_(1 - weight_decay)
                 p.add_(delta, alpha=-step_size)

         return loss
-
-
-class Noam(object):
-    """
-    Implements Noam optimizer.
-
-    Proposed in
-    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
-
-    Args:
-      params:
-        iterable of parameters to optimize or dicts defining parameter groups
-      model_size:
-        attention dimension of the transformer model
-      factor:
-        learning rate factor
-      warm_step:
-        warmup steps
-    """
-
-    def __init__(
-        self,
-        params,
-        model_size: int = 256,
-        factor: float = 10.0,
-        warm_step: int = 25000,
-        weight_decay=0,
-    ) -> None:
-        """Construct an Noam object."""
-        self.optimizer = torch.optim.Adam(
-            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
-        )
-        self._step = 0
-        self.warmup = warm_step
-        self.factor = factor
-        self.model_size = model_size
-        self._rate = 0
-
-    @property
-    def param_groups(self):
-        """Return param_groups."""
-        return self.optimizer.param_groups
-
-    def step(self):
-        """Update parameters and rate."""
-        self._step += 1
-        rate = self.rate()
-        for p in self.optimizer.param_groups:
-            p["lr"] = rate
-        self._rate = rate
-        self.optimizer.step()
-
-    def rate(self, step=None):
-        """Implement `lrate` above."""
-        if step is None:
-            step = self._step
-        return (
-            self.factor
-            * self.model_size ** (-0.5)
-            * self.warmup ** (-0.5 - -0.333)
-            * min(step ** (-0.333), step * self.warmup ** (-1.333))
-        )
-
-    def zero_grad(self):
-        """Reset gradient."""
-        self.optimizer.zero_grad()
-
-    def state_dict(self):
-        """Return state_dict."""
-        return {
-            "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        """Load state_dict."""
-        for key, value in state_dict.items():
-            if key == "optimizer":
-                self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
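The revised shrinkage formula admits a quick sanity check. If each step first shrinks p by the factor (1 - weight_decay) and then adds -step_size * delta, and delta is roughly uncorrelated with p, the variance of p settles where the shrinkage removes as much as the update injects, i.e. rms(p)^2 ≈ step_size^2 * mean(delta^2) / (2 * weight_decay) = target_rms^2. The toy simulation below is not part of the commit: it substitutes unit-RMS Gaussian noise for the Adam update direction delta, and only illustrates that under this assumption the parameter RMS drifts toward target_rms.

import torch

# Toy check of the revised Eve shrinkage rule; `delta` here is only a stand-in
# for exp_avg / denom, assumed to have RMS ~= 1 and to be uncorrelated with p.
torch.manual_seed(0)
step_size, target_rms = 0.01, 0.1
p = torch.zeros(4000)
for _ in range(2000):
    delta = torch.randn_like(p)
    weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
    p.mul_(1 - weight_decay).add_(delta, alpha=-step_size)
print((p ** 2).mean().sqrt())  # prints roughly 0.1, i.e. close to target_rms

Under the same assumptions, the previous expression (square root of mean(delta^2), with the 0.5 factor squared as well) gives an equilibrium RMS that depends on the scale of delta rather than landing at target_rms, which is presumably why it was changed.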

View File

@@ -48,7 +48,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer, Noam
+from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -437,7 +437,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -652,7 +652,7 @@ def train_one_epoch(
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
-        lr_scheduler.step()
+        scheduler.step()

         if params.print_diagnostics and batch_idx == 5:
             return
@@ -848,7 +848,7 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

-        cur_lr = optimizer._rate
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
@@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
             loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup = 0.0
             )
             loss.backward()
             optimizer.step()
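With the Noam wrapper removed from this recipe, train.py reads the current learning rate from a standard PyTorch scheduler via get_last_lr() instead of the wrapper's private _rate attribute. The sketch below is not from this commit: StepLR stands in for whatever _LRScheduler subclass the recipe actually constructs, and only illustrates the API used by the new cur_lr line.

import torch

# Generic sketch: any torch.optim.lr_scheduler._LRScheduler exposes
# get_last_lr(), which returns one learning rate per parameter group.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
for _ in range(25):
    optimizer.step()      # forward/backward omitted for brevity
    scheduler.step()
cur_lr = scheduler.get_last_lr()[0]
print(cur_lr)             # 0.00025 after decays at steps 10 and 20

Because get_last_lr() is part of the _LRScheduler interface, the cur_lr change above works unchanged for any scheduler the recipe plugs in, unlike the old optimizer._rate, which only existed on the Noam wrapper.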