Slight refactoring, preparing for batching.

This commit is contained in:
Daniel Povey 2022-07-09 22:24:36 -07:00
parent d9a6180ae0
commit d25df4af5e

View File

@ -21,7 +21,6 @@ import torch
import random
from torch import Tensor
from torch.optim import Optimizer
from icefall import diagnostics # only for testing code
import logging
class PrAdam(Optimizer):
@ -120,15 +119,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
loss = closure()
for group in self.param_groups:
lr = group["lr"]
size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
size_update_period = group["size_update_period"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
lr_update_period = group["lr_update_period"]
for p in group["params"]:
if p.grad is None:
@ -207,7 +199,25 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
self._step_one_param(group, p, state)
return loss
def _step_one_param(self,
group: dict,
p: Tensor,
state: dict):
lr = group["lr"]
size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
size_update_period = group["size_update_period"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
lr_update_period = group["lr_update_period"]
grad = p.grad
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
@ -227,7 +237,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
beta1, beta2, step, size_lr,
param_min_rms, param_max_rms)
if numel == 1:
# For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall
@ -242,7 +251,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
p.add_(delta)
state["step"] = step + 1
return loss
def _size_update(self,
p: Tensor,
@ -1343,7 +1351,8 @@ def _test_eve_cain():
B = 4
T = 2
logging.info("in test_eve_cain")
device = torch.device('cuda')
#device = torch.device('cuda')
device = torch.device('cpu')
dtype = torch.float32
fix_random_seed(42)
@ -1376,11 +1385,11 @@ def _test_eve_cain():
#if epoch == 100 and iter in [2,3]:
# optim.reset_speedup() # check it doesn't crash.
if epoch == 130:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(m, opts)
#if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)
for n, (x,y) in enumerate(train_pairs):