Mirror of https://github.com/k2-fsa/icefall.git

commit d1f2f93460
parent 72f4a673b1

    Some fixes..
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (generally 0.1). This is
+    rms of the parameters approach a particular specified value (we suggest 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
 
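Note (not part of the commit): "rms of the parameters" here is the root-mean-square of a weight tensor's entries. A minimal sketch of the quantity the shrinkage steers toward target_rms, with an invented tensor and the suggested value of 0.1:

    import torch

    target_rms = 0.1
    w = torch.randn(512, 512) * target_rms   # stand-in for a weight matrix of a 'scaled' module
    rms = (w ** 2).mean().sqrt()
    print(f"rms = {rms:.3f}, target_rms = {target_rms}")   # rms is close to 0.1 for this init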
@@ -120,7 +120,7 @@ class Eve(Optimizer):
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
 
                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
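Note (not part of the commit): the denom change is an algebraic rewrite rather than a behavioural one, since dividing by math.sqrt(bias_correction2) equals multiplying by bias_correction2 ** -0.5. A small self-contained check with made-up values:

    import math
    import torch

    exp_avg_sq = torch.rand(4) + 1e-8        # stand-in for the second-moment running average
    bias_correction2 = 1 - 0.98 ** 100       # e.g. beta2 = 0.98 after 100 steps

    old = exp_avg_sq.sqrt() / math.sqrt(bias_correction2)
    new = exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)
    assert torch.allclose(old, new)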
@@ -141,7 +141,7 @@ class Eve(Optimizer):
                 # Suppose we are going to shrinkage with a small value epsilon (not the
                 # same as the eps above!), i.e. param *= (1-epsilon). Then
                 # if E[param_elem^2] == target_rms^2,
-                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
+                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
                 # which we can put as:
                 # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
                 # Setting delta_var_from_shrinkage = -delta_var_from_update
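Note (not part of the commit): spelling out where that comment block is heading, with example numbers only: dropping the epsilon^2 term and balancing the variance removed by shrinkage against the variance added by one update gives epsilon = 0.5 * step_size^2 * E[delta^2] / target_rms^2.

    target_rms = 0.1                 # example values only
    step_size = 1e-3
    mean_delta_sq = 0.5              # stand-in for E[delta_elem^2]

    delta_var_from_update = step_size ** 2 * mean_delta_sq
    epsilon = delta_var_from_update / (2 * target_rms ** 2)

    # to first order in epsilon, shrinking by (1 - epsilon) removes
    # 2 * epsilon * target_rms^2 of variance, cancelling what the update added:
    assert abs(2 * epsilon * target_rms ** 2 - delta_var_from_update) < 1e-12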
@@ -157,98 +157,9 @@ class Eve(Optimizer):
                 # decay.
 
                 # this is the weight-decay amount...
-                weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
+                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
 
                 p.mul_(1 - weight_decay)
                 p.add_(delta, alpha=-step_size)
 
         return loss
-
-
-
-class Noam(object):
-    """
-    Implements Noam optimizer.
-
-    Proposed in
-    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
-
-    Args:
-      params:
-        iterable of parameters to optimize or dicts defining parameter groups
-      model_size:
-        attention dimension of the transformer model
-      factor:
-        learning rate factor
-      warm_step:
-        warmup steps
-    """
-
-    def __init__(
-        self,
-        params,
-        model_size: int = 256,
-        factor: float = 10.0,
-        warm_step: int = 25000,
-        weight_decay=0,
-    ) -> None:
-        """Construct an Noam object."""
-        self.optimizer = torch.optim.Adam(
-            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
-        )
-        self._step = 0
-        self.warmup = warm_step
-        self.factor = factor
-        self.model_size = model_size
-        self._rate = 0
-
-    @property
-    def param_groups(self):
-        """Return param_groups."""
-        return self.optimizer.param_groups
-
-    def step(self):
-        """Update parameters and rate."""
-        self._step += 1
-        rate = self.rate()
-        for p in self.optimizer.param_groups:
-            p["lr"] = rate
-        self._rate = rate
-        self.optimizer.step()
-
-    def rate(self, step=None):
-        """Implement `lrate` above."""
-        if step is None:
-            step = self._step
-        return (
-            self.factor
-            * self.model_size ** (-0.5)
-            * self.warmup ** (-0.5 - -0.333)
-            * min(step ** (-0.333), step * self.warmup ** (-1.333))
-        )
-
-    def zero_grad(self):
-        """Reset gradient."""
-        self.optimizer.zero_grad()
-
-    def state_dict(self):
-        """Return state_dict."""
-        return {
-            "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        """Load state_dict."""
-        for key, value in state_dict.items():
-            if key == "optimizer":
-                self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
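Note (not part of the commit): read this way, the corrected weight_decay line is exactly the epsilon derived in the comment block above, whereas the removed version (with the extra .sqrt() and the squared 0.5) does not reduce to the same expression. A hedged restatement with invented values:

    import torch

    step_size, target_rms = 1e-3, 0.1        # example values only
    delta = torch.randn(256, 256)            # stand-in for the update direction exp_avg / denom

    weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
    epsilon = step_size ** 2 * (delta ** 2).mean() / (2 * target_rms ** 2)
    assert torch.allclose(weight_decay, epsilon)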
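Note (not part of the commit): for reference, the learning-rate formula that the deleted Noam class computed, isolated as a sketch (the hypothetical function name is not from the repo; the defaults are the ones in the removed __init__). The rate rises linearly up to warm_step and then decays like step ** (-0.333), i.e. a Noam-style schedule with a 1/3 exponent instead of the usual 1/2; the extra warm_step ** (-0.5 - -0.333) factor keeps the peak value the same as the standard schedule's. The hunks below drop the class's import and its optimizer._rate usage from the training script.

    def removed_noam_rate(step: int, model_size: int = 256,
                          factor: float = 10.0, warm_step: int = 25000) -> float:
        # Same expression as the deleted Noam.rate(); valid for step >= 1.
        return (
            factor
            * model_size ** (-0.5)
            * warm_step ** (-0.5 - -0.333)
            * min(step ** (-0.333), step * warm_step ** (-1.333))
        )

    # e.g. the peak learning rate is reached around step == warm_step:
    print(removed_noam_rate(25000))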
@@ -48,7 +48,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer, Noam
+from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -437,7 +437,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -652,7 +652,7 @@ def train_one_epoch(
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
-        lr_scheduler.step()
+        scheduler.step()
 
         if params.print_diagnostics and batch_idx == 5:
             return
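Note (not part of the commit): besides the rename, this hunk shows the per-batch update order; PyTorch expects a scheduler that is stepped every batch to be stepped after optimizer.step(). A generic, self-contained illustration (none of these objects are from the recipe):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    for _ in range(3):                       # stand-in for the training loop
        loss = model(torch.randn(8, 10)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()                     # after optimizer.step(), as in the diff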
@@ -848,7 +848,7 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
-        cur_lr = optimizer._rate
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
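Note (not part of the commit): scheduler.get_last_lr() returns a list with one learning rate per parameter group, hence the [0] when logging a single scalar. A minimal, invented example:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=3e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

    cur_lr = scheduler.get_last_lr()[0]      # one entry per param group
    print(cur_lr)                            # 0.0003 here, since there is a single group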
@@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
             loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup = 0.0
             )
             loss.backward()
             optimizer.step()
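Note (not part of the commit): the added comment explains the intent of warmup = 0.0; the sketch below is only a guess at the general shape of how such a warmup value can zero out the pruned-loss term (the recipe's real compute_loss/model code is not shown in this diff), so that the affected parameters see exactly zero gradients and Adam's running averages for them stay untouched during this OOM scan:

    import torch

    def combined_loss(simple_loss: torch.Tensor,
                      pruned_loss: torch.Tensor,
                      warmup: float) -> torch.Tensor:
        # Hypothetical weighting: the pruned loss only contributes once warmup
        # is large enough.  With warmup = 0.0 its scale is 0.0, so its gradient
        # contribution is exactly zero and Adam's exp_avg / exp_avg_sq for the
        # pruned-head parameters are never populated by these dummy batches.
        pruned_scale = 0.0 if warmup < 1.0 else 1.0
        return simple_loss + pruned_scale * pruned_loss

    # usage sketch: loss = combined_loss(simple_loss, pruned_loss, warmup=0.0)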