Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Some fixes..

commit d1f2f93460
parent 72f4a673b1
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (generally 0.1). This is
+    rms of the parameters approach a particular specified value (we suggest 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
 
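For context on the docstring claim above, here is a minimal sketch (the helper below is hypothetical, not part of this diff) for inspecting whether parameter matrices actually settle near the rms value that Eve's shrinkage steers them towards:

```python
# Hypothetical helper, not part of the diff: report the rms of each >=2-D parameter,
# which Eve's shrinkage is designed to pull towards target_rms (e.g. 0.1).
import torch


def param_rms(module: torch.nn.Module) -> dict:
    return {
        name: p.detach().pow(2).mean().sqrt().item()
        for name, p in module.named_parameters()
        if p.ndim >= 2
    }

# Usage (after some training with Eve): values should hover near ~0.1.
# print(param_rms(model))
```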
@@ -120,7 +120,7 @@ class Eve(Optimizer):
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
 
                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
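The denom rewrite above is algebraically a no-op: dividing by math.sqrt(bias_correction2) and multiplying by bias_correction2 ** -0.5 produce the same tensor. A quick sanity check with made-up values (illustrative only):

```python
import math

import torch

exp_avg_sq = torch.rand(4)
bias_correction2 = 1 - 0.98 ** 100  # e.g. beta2 = 0.98 after 100 steps

old = exp_avg_sq.sqrt() / math.sqrt(bias_correction2)
new = exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)
assert torch.allclose(old, new)
```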
@@ -141,7 +141,7 @@ class Eve(Optimizer):
                 # Suppose we are going to shrinkage with a small value epsilon (not the
                 # same as the eps above!), i.e. param *= (1-epsilon). Then
                 # if E[param_elem^2] == target_rms^2,
-                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2),
+                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
                 # which we can put as:
                 # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
                 # Setting delta_var_from_shrinkage = -delta_var_from_update
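A quick numeric check of the first-order approximation these comments rely on (illustrative values only): shrinking a parameter by (1 - epsilon) changes E[param^2] by roughly -2 * epsilon * target_rms^2.

```python
import torch

target_rms = 0.1
epsilon = 1e-4
p = torch.randn(100_000) * target_rms  # parameters whose rms is roughly target_rms

exact = ((p * (1 - epsilon)) ** 2).mean() - (p ** 2).mean()
approx = -2 * epsilon * target_rms ** 2
print(float(exact), approx)  # both close to -2e-06
```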
@@ -157,98 +157,9 @@ class Eve(Optimizer):
                 # decay.
 
                 # this is the weight-decay amount...
-                weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2)
+                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2)
 
                 p.mul_(1 - weight_decay)
                 p.add_(delta, alpha=-step_size)
 
         return loss
-
-
-
-class Noam(object):
-    """
-    Implements Noam optimizer.
-
-    Proposed in
-    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
-
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
-
-    Args:
-        params:
-            iterable of parameters to optimize or dicts defining parameter groups
-        model_size:
-            attention dimension of the transformer model
-        factor:
-            learning rate factor
-        warm_step:
-            warmup steps
-    """
-
-    def __init__(
-        self,
-        params,
-        model_size: int = 256,
-        factor: float = 10.0,
-        warm_step: int = 25000,
-        weight_decay=0,
-    ) -> None:
-        """Construct an Noam object."""
-        self.optimizer = torch.optim.Adam(
-            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
-        )
-        self._step = 0
-        self.warmup = warm_step
-        self.factor = factor
-        self.model_size = model_size
-        self._rate = 0
-
-    @property
-    def param_groups(self):
-        """Return param_groups."""
-        return self.optimizer.param_groups
-
-    def step(self):
-        """Update parameters and rate."""
-        self._step += 1
-        rate = self.rate()
-        for p in self.optimizer.param_groups:
-            p["lr"] = rate
-        self._rate = rate
-        self.optimizer.step()
-
-    def rate(self, step=None):
-        """Implement `lrate` above."""
-        if step is None:
-            step = self._step
-        return (
-            self.factor
-            * self.model_size ** (-0.5)
-            * self.warmup ** (-0.5 - -0.333)
-            * min(step ** (-0.333), step * self.warmup ** (-1.333))
-        )
-
-    def zero_grad(self):
-        """Reset gradient."""
-        self.optimizer.zero_grad()
-
-    def state_dict(self):
-        """Return state_dict."""
-        return {
-            "_step": self._step,
-            "warmup": self.warmup,
-            "factor": self.factor,
-            "model_size": self.model_size,
-            "_rate": self._rate,
-            "optimizer": self.optimizer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        """Load state_dict."""
-        for key, value in state_dict.items():
-            if key == "optimizer":
-                self.optimizer.load_state_dict(state_dict["optimizer"])
-            else:
-                setattr(self, key, value)
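For reference, the schedule that the deleted Noam class implemented can be written as a standalone function (a sketch reproducing the removed rate() method, not code introduced by this commit):

```python
def noam_rate(step: int, model_size: int = 256, factor: float = 10.0, warmup: int = 25000) -> float:
    # Rises while step < warmup, then decays like step ** -0.333.
    return (
        factor
        * model_size ** (-0.5)
        * warmup ** (-0.5 - -0.333)
        * min(step ** (-0.333), step * warmup ** (-1.333))
    )

# The peak, at step == warmup, equals factor * model_size ** -0.5 * warmup ** -0.5 (~0.004 here).
print(noam_rate(25000))
```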
@@ -48,7 +48,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer, Noam
+from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -437,7 +437,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -652,7 +652,7 @@ def train_one_epoch(
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
-        lr_scheduler.step()
+        scheduler.step()
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -848,7 +848,7 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
-        cur_lr = optimizer._rate
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
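The cur_lr change reflects that the learning rate is now read from a torch-style LR scheduler rather than the Noam wrapper's private _rate field; get_last_lr() returns one value per parameter group. A small standalone illustration (StepLR is just an example scheduler here, not the one used by the recipe):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for _ in range(25):
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()

cur_lr = scheduler.get_last_lr()[0]
print(cur_lr)  # 1e-3 * 0.5 ** 2 == 2.5e-04
```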
@@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
             loss, _ = compute_loss(
                 params=params,
                 model=model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup=0.0,
             )
             loss.backward()
             optimizer.step()