icefall (https://github.com/k2-fsa/icefall.git)
commit d25df4af5e (parent d9a6180ae0)

Slight refactoring, preparing for batching.
@@ -21,7 +21,6 @@ import torch
 import random
 from torch import Tensor
 from torch.optim import Optimizer
-from icefall import diagnostics # only for testing code
 import logging

 class PrAdam(Optimizer):
@@ -120,15 +119,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 loss = closure()

         for group in self.param_groups:
-            lr = group["lr"]
-            size_lr = lr * group["size_lr_scale"]
-            beta1, beta2 = group["betas"]
-            scalar_max = group["scalar_max"]
             eps = group["eps"]
             size_update_period = group["size_update_period"]
-            param_min_rms = group["param_min_rms"]
-            param_max_rms = group["param_max_rms"]
-            lr_update_period = group["lr_update_period"]

             for p in group["params"]:
                 if p.grad is None:
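Not part of the diff: a minimal usage sketch of the closure-driven step() that the "loss = closure()" context above implies. The constructor call is an assumption; only the param-group keys ("lr", "betas", "eps", ...) are visible in this commit, not the actual __init__ signature.

import torch

model = torch.nn.Linear(16, 16)      # toy model, for illustration only
optim = PrAdam(model.parameters())   # assumed: defaults for all hyper-parameters

x, y = torch.randn(4, 16), torch.randn(4, 16)

def closure():
    optim.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optim.step(closure)           # step() runs the closure and returns its loss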
@@ -207,43 +199,59 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                         # instead of just using a temporary and smoothing the scalar factor.
                         state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)

-                step = state["step"]
-                delta = state["delta"]
-                delta.mul_(beta1)
-                numel = p.numel()
-                if numel > 1:
-                    # Update the size/scale of p, and set param_rms
-                    scale_grads = state["scale_grads"]
-                    scale_grads[step % size_update_period] = (p * grad).sum()
-                    if step % size_update_period == size_update_period - 1:
-                        # this learns the overall scale on the parameter, by shrinking or
-                        # expanding it.
-                        param_rms = state["param_rms"]
-                        param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
-                        if step > 0:
-                            self._size_update(p, state,
-                                              scale_grads, param_rms,
-                                              beta1, beta2, step, size_lr,
-                                              param_min_rms, param_max_rms)
-
-                if numel == 1:
-                    # For parameters with very few elements we just use a form
-                    # of Adam with a scale factor to reflect the overall
-                    # parameter rms.  Updates delta.
-                    self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
-                else:
-                    if step % lr_update_period == 0 and step > 0:
-                        self._accum_param_covs(group, p, state)
-                        self._update_lrs(group, p, state)
-                        self._zero_exp_avg_sq(state)
-                    self._step(group, p, grad, state)
-                p.add_(delta)
-                state["step"] = step + 1
+                self._step_one_param(group, p, state)

         return loss

+    def _step_one_param(self,
+                        group: dict,
+                        p: Tensor,
+                        state: dict):
+        lr = group["lr"]
+        size_lr = lr * group["size_lr_scale"]
+        beta1, beta2 = group["betas"]
+        scalar_max = group["scalar_max"]
+        eps = group["eps"]
+        size_update_period = group["size_update_period"]
+        param_min_rms = group["param_min_rms"]
+        param_max_rms = group["param_max_rms"]
+        lr_update_period = group["lr_update_period"]
+
+        grad = p.grad
+        step = state["step"]
+        delta = state["delta"]
+        delta.mul_(beta1)
+        numel = p.numel()
+        if numel > 1:
+            # Update the size/scale of p, and set param_rms
+            scale_grads = state["scale_grads"]
+            scale_grads[step % size_update_period] = (p * grad).sum()
+            if step % size_update_period == size_update_period - 1:
+                # this learns the overall scale on the parameter, by shrinking or
+                # expanding it.
+                param_rms = state["param_rms"]
+                param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
+                if step > 0:
+                    self._size_update(p, state,
+                                      scale_grads, param_rms,
+                                      beta1, beta2, step, size_lr,
+                                      param_min_rms, param_max_rms)
+
+        if numel == 1:
+            # For parameters with very few elements we just use a form
+            # of Adam with a scale factor to reflect the overall
+            # parameter rms.  Updates delta.
+            self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
+        else:
+            if step % lr_update_period == 0 and step > 0:
+                self._accum_param_covs(group, p, state)
+                self._update_lrs(group, p, state)
+                self._zero_exp_avg_sq(state)
+            self._step(group, p, grad, state)
+        p.add_(delta)
+        state["step"] = step + 1
+
     def _size_update(self,
                      p: Tensor,
                      state: dict,
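Pieced together from the two hunks above, step() now delegates all per-parameter work to _step_one_param(). The sketch below shows the resulting control flow; the closure wrapper and the state lookup are ordinary PyTorch-optimizer boilerplate assumed here (they are not shown in the diff), and the untouched state-initialization code is elided.

def step(self, closure=None):
    """Sketch of step() after this commit (not verbatim from the file)."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # eps and size_update_period stay in the outer loop, presumably because
        # the untouched state-initialization code (not shown) still uses them.
        eps = group["eps"]
        size_update_period = group["size_update_period"]

        for p in group["params"]:
            if p.grad is None:
                continue
            state = self.state[p]   # assumed; standard Optimizer state lookup
            # ... lazy state initialization, unchanged by this commit ...
            self._step_one_param(group, p, state)

    return loss

Batching, per the commit message, would presumably replace the inner Python loop with grouped tensor operations, but that change is not part of this commit.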
@@ -1343,7 +1351,8 @@ def _test_eve_cain():
     B = 4
     T = 2
     logging.info("in test_eve_cain")
-    device = torch.device('cuda')
+    #device = torch.device('cuda')
+    device = torch.device('cpu')
     dtype = torch.float32

     fix_random_seed(42)
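Not in the commit: if the CPU/CUDA choice should not be hard-coded in _test_eve_cain(), the usual availability check would be:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")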
@@ -1376,11 +1385,11 @@ def _test_eve_cain():
         #if epoch == 100 and iter in [2,3]:
         #    optim.reset_speedup() # check it doesn't crash.

-        if epoch == 130:
-            opts = diagnostics.TensorDiagnosticOptions(
-                2 ** 22
-            ) # allow 4 megabytes per sub-module
-            diagnostic = diagnostics.attach_diagnostics(m, opts)
+        #if epoch == 130:
+        #    opts = diagnostics.TensorDiagnosticOptions(
+        #        2 ** 22
+        #    ) # allow 4 megabytes per sub-module
+        #    diagnostic = diagnostics.attach_diagnostics(m, opts)


         for n, (x,y) in enumerate(train_pairs):
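For reference, the diagnostics hook disabled above is normally used as an attach/run/print sequence, sketched below. Treat print_diagnostics() as an assumption about the icefall API; only TensorDiagnosticOptions and attach_diagnostics appear in this diff. Note that the first hunk also drops the "from icefall import diagnostics" import, so re-enabling the block would require restoring it.

from icefall import diagnostics                        # removed by this commit; needed again if re-enabled

opts = diagnostics.TensorDiagnosticOptions(2 ** 22)    # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(m, opts)   # m: the model built in the test

for n, (x, y) in enumerate(train_pairs):
    ...  # forward / backward / optim.step(), as in the test loop

diagnostic.print_diagnostics()                         # assumed API: print the collected tensor statistics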