Remove rank-1 dims, i.e. dims where size == numel(), from processing.
This commit is contained in:
parent 2fc9eb9789
commit ffeef4ede4
@@ -178,8 +178,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         for dim in range(p.ndim):
             size = p.shape[dim]
-            if size == 1:
-                # if size == p.numel(), is_one_axis == True above
+            if size == 1 or size == p.numel():
                 continue

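The new condition also skips any dimension whose size equals the whole tensor's numel(), i.e. tensors where every other axis is trivial; those are effectively rank-1 (one-axis) parameters and, as the comment in the hunk notes, are already covered by the is_one_axis path. A minimal standalone sketch of the skip rule (the helper name is ours, not from the diff):

import torch

def nontrivial_dims(p: torch.Tensor):
    # Dims that still get the per-dimension treatment: skip size-1 dims and the
    # case where a single dim accounts for the whole tensor (size == numel()),
    # e.g. a bias of shape (256,) or a scale of shape (1, 256, 1).
    return [dim for dim, size in enumerate(p.shape)
            if size != 1 and size != p.numel()]

print(nontrivial_dims(torch.zeros(512, 256)))   # [0, 1]: both dims processed
print(nontrivial_dims(torch.zeros(1, 256, 1)))  # []: size == numel(), skipped
print(nontrivial_dims(torch.zeros(256)))        # []: rank-1, skipped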
@@ -320,7 +319,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                                         (step + lr_update_period))
         for dim in range(p.ndim):
             size = p.shape[dim]
-            if size == 1:
+            if size == 1 or size == p.numel():
                 continue

             # M is the parameter matrix, of shape (-1, size) where the -1 covers
@@ -363,7 +362,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         for dim in range(ndim):
             size = p.shape[dim]
-            if size == 1:
+            if size == 1 or size == p.numel():
                 continue
             param_cov = state[f"param_cov_{dim}"]
             U, S, V = _svd(param_cov)
@@ -472,7 +471,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # grad_cov_{dim}, which are for periodically updating the
             # learning-rate matrices.
             size = grad.shape[dim]
-            if size == 1:
+            if size == 1 or size == p.numel():
                 continue
             if step % grad_cov_period == 0 or step < grad_cov_period:
                 grad_cov = state[f"grad_cov_{dim}"]
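The surrounding condition, step % grad_cov_period == 0 or step < grad_cov_period, accumulates gradient-covariance stats on every step early in training and only periodically afterwards. A quick illustration (the grad_cov_period value is made up):

grad_cov_period = 10  # illustrative value only, not taken from the diff
accumulate_at = [s for s in range(35)
                 if s % grad_cov_period == 0 or s < grad_cov_period]
print(accumulate_at)  # [0, 1, 2, ..., 9, 10, 20, 30]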
@@ -559,8 +558,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         co-ordinates); if False, the reverse.  They differ by a transpose,
         not an inverse, and do not make a round trip.
         """
+        numel = x.numel()
         for dim in range(x.ndim):
-            if x.shape[dim] == 1:
+            size = x.shape[dim]
+            if size == 1 or size == numel:
                 continue
             Q = state[f"Q_{dim}"]
             if forward:
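The docstring above stresses that the forward and backward applications of a projection differ by a transpose rather than an inverse, so applying both is not a round trip. A small sketch of that point, assuming (purely for illustration) that the forward direction multiplies by Q^T and the backward direction by Q:

import torch

torch.manual_seed(0)
size = 4
Q = torch.randn(size, size)      # a non-orthogonal basis change
g = torch.randn(3, size)         # e.g. a gradient flattened over the other dims

g_new = g @ Q.t()                # "forward": gradient -> gradient in new basis
back = g_new @ Q                 # "backward": uses the transpose, not the inverse
print(torch.allclose(back, g))                                # False: no round trip
print(torch.allclose(g_new @ Q.t().inverse(), g, atol=1e-5))  # True: only the true inverse round-trips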
@@ -575,236 +576,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                x = x.transpose(-1, dim)
        return x

    def _store_grad_stats(self,
                          grad: Tensor,
                          state: dict,
                          beta2: float) -> None:
        """
        Accumulate some stats for each dimension of the tensor, about the covariance of gradients.
        The stats are just summed, not decayed, as the scalar factor will be normalized out later.
        """
        ndim = grad.ndim
        kwargs = {'device': grad.device, 'dtype': grad.dtype}
        for dim in range(ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            grad_cov = state[f"grad_cov_{dim}"]
            g = grad.transpose(dim, -1)
            g = g.reshape(-1, g.shape[-1])
            # We don't care about the scalar factor, we will normalize when we
            # use it.
            grad_cov.mul_(beta2).add_(torch.matmul(g.t(), g))

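_store_grad_stats builds, for each non-trivial dimension, a (size, size) Gram matrix of the gradient with that dimension moved last and all other dimensions flattened. A self-contained sketch of that reshaping pattern, without the beta2 running average:

import torch

def dim_cov(grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Move `dim` to the end, flatten everything else, and form a Gram matrix.
    g = grad.transpose(dim, -1)
    g = g.reshape(-1, g.shape[-1])
    return g.t() @ g             # shape: (grad.shape[dim], grad.shape[dim])

grad = torch.randn(8, 16, 32)
print(dim_cov(grad, 1).shape)    # torch.Size([16, 16])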
    def _estimate_projections(self,
                              p: Tensor,
                              state: dict,
                              param_eps: float,
                              param_rel_eps: float,
                              param_rel_max: float,
                              param_reverse_cutoff: float,
                              beta3: float,
                              param_pow: float,
                              grad_pow: float,
                              grad_min_rand: float,
                              ) -> None:
        """
        Estimate the per-dimension projections proj_0, proj_1 and so on.
        These projections, when applied with forward==True by _change_coordinates(),
        first rotate into a space where the optimal learning-rate-scale is
        diagonalized (see self._estimate_proj); then scale by the parameter scale.
        When applied with forward == False they first scale by the parameter scale
        and then rotate back into canonical co-ordinates.  (Thus, applying with
        forward==True and then forward==False is not a round trip.)

        Args:
            p: The parameter for which we are estimating the projections.  Supplied
                because we need this to estimate parameter covariance matrices in
                each of its tensor directions.
            state: The state dict for this parameter tensor.
            param_eps: Small epsilon representing a minimum parameter magnitude;
                once the estimated parameter rms for some element of p gets below
                this value, we'll treat it as having this value.
            param_rel_eps: Another constraint on minimum parameter rms, relative to
                sqrt(overall_param_rms), where overall_param_rms is the sqrt of the
                parameter variance averaged over the entire tensor.
            beta3: Discounting parameter for param_cov, applied once every estimate_period.
            param_pow: param_pow==1.0 means fully normalizing for parameter scale;
                setting to a smaller value closer to 0.0 means partial normalization.
            grad_min_rand: Minimum proportion of random tensor to mix with the
                gradient covariance matrix.
        """
        kwargs = {'device': p.device, 'dtype': p.dtype}

        ndim = p.ndim

        for dim in range(ndim):
            size = p.shape[dim]
            if size == 1:
                continue
            proj = state[f"proj_{dim}"]

            if proj.ndim == 2:
                # This dimension gets the full-covariance treatment.
                count = state[f"grad_cov_count_{dim}"]
                assert count != 0
                grad_cov = state[f"grad_cov_{dim}"] / count
                del state[f"grad_cov_{dim}"]  # save memory

                #self._randomize_lowrank_cov(grad_cov, count, p.shape,
                #                            grad_min_rand)

                cur_param_cov = self._get_param_cov(p, dim)
                cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20)  # normalize size..
                param_cov = state[f"param_cov_{dim}"]
                param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1-beta3))

                # We want the orthogonal matrix U that diagonalizes P, where
                # P is the SPD matrix such that P G^{param_pow} P^T == C^{param_pow},
                # where G == grad_cov and C == param_cov.
                # .. this is the same as the U that diagonalizes a different P
                # such that P G P^T == C^{param_pow/(2-param_pow)},
                # since the P's are related by taking-to-a-power.

                num_samples = p.numel() // size
                reverse_cutoff = (param_reverse_cutoff if num_samples > size//4 else 1.0e+10)
                P = self._estimate_proj(grad_cov,
                                        param_cov,
                                        param_pow / grad_pow,
                                        param_rel_eps,
                                        param_rel_max,
                                        reverse_cutoff)

                # The only thing we want from P is the basis that diagonalizes
                # it, i.e. if we do the symmetric svd P = U S U^T, we can
                # interpret the shape of U as (canonical_coordinate,
                # diagonalized_coordinate)
                U, S, _ = P.svd()
                # `proj` will be of shape (diagonalized_coordinate, canonical_coordinate).
                # Later we will modify `proj` to incorporate the parameter scale.
                proj[:] = U.t()
            else:
                # This dimension will be treated diagonally.  For now just set `proj` to
                # all ones; later we will modify it in _estimate_params_scale().
                proj.fill_(1.0)

        self._estimate_param_scale(p, state, param_eps,
                                   param_pow, param_rel_eps, param_rel_max,
                                   param_reverse_cutoff)

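The per-dimension parameter covariance is kept as a running average: each fresh estimate is normalized so its diagonal mean is 1 and then blended in with weight (1 - beta3). _get_param_cov itself is not shown in this diff, so the sketch below simply assumes some fresh SPD estimate for it (the beta3 value is illustrative):

import torch

size, beta3 = 16, 0.9                      # beta3 value is illustrative
param_cov = torch.eye(size)                # running estimate kept in `state`
A = torch.randn(size, size)
cur_param_cov = A @ A.t()                  # stand-in for _get_param_cov's output (SPD)

cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20)
param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1 - beta3))
print(param_cov.diag().mean())             # stays close to 1.0 by construction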
    def _multiply_by_lr(self,
                        grad: Tensor,
                        state: dict,
                        forward: bool) -> Tensor:
        """
        Multiply the grad by a positive semi-definite matrix for each dimension,
        interpreted as a non-scalar learning-rate factor.
        """
        cur_grad = grad
        for dim in range(grad.ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            M = state[f"lr_{dim}"]
            if M.ndim == 1:
                # We are treating this dimension diagonally because it is too big.
                M = M.reshape(size, *([1] * (ndim-1-dim)))
                cur_grad = cur_grad * M
            else:
                cur_grad = self._multiply_on_dim(cur_grad, M, dim)
        return cur_grad

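For dimensions treated diagonally, the 1-D learning-rate vector is reshaped so that it broadcasts along dimension dim of the gradient; ndim in the method above is presumably grad.ndim. A shape-only sketch:

import torch

grad = torch.randn(8, 16, 32)
dim = 1
size = grad.shape[dim]
lr_diag = torch.rand(size)                 # per-index learning-rate factors

# Reshape so the vector lines up with dimension `dim` of `grad` when broadcasting;
# `ndim` in the method above is presumably grad.ndim.
M = lr_diag.reshape(size, *([1] * (grad.ndim - 1 - dim)))
print(M.shape)                             # torch.Size([16, 1])
print((grad * M).shape)                    # torch.Size([8, 16, 32])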
    def _multiply_by_grad_cov(self,
                              grad: Tensor,
                              state: dict,
                              forward: bool) -> Tensor:
        """
        Multiply the grad by a positive semi-definite matrix for each dimension,
        interpreted as a non-scalar learning-rate factor.
        """
        cur_grad = grad
        for dim in range(grad.ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            M = state[f"grad_cov_{dim}"]
            if M.ndim == 1:
                # We are treating this dimension diagonally because it is too big.
                M = M.reshape(size, *([1] * (ndim-1-dim)))
                # Dividing by (M.mean() + 1.0e-20) is just arbitrary normalization,
                # to avoid getting very large or very small outputs.
                cur_grad = cur_grad * (M / (M.mean() + 1.0e-20))
            else:
                # Dividing by (M.diag().mean() + 1.0e-20) is just arbitrary normalization,
                # to avoid getting very large or very small outputs.  We only need the result
                # up to an arbitrary scalar factor.
                cur_grad = self._multiply_on_dim(cur_grad, M, dim) / (M.diag().mean() + 1.0e-20)
        return cur_grad

    def _change_coordinates(self,
                            grad: Tensor,
                            state: dict,
                            forward: bool) -> Tensor:
        """
        Multiply the grad by a matrix for each dimension, interpreted as a non-orthogonal
        basis change.
        If forward == True, `grad` is interpreted as: gradient -> gradient in new basis;
        if forward == False, it is interpreted as:
        parameter-change in new basis -> parameter-change in canonical basis.
        These differ by a transpose (this is not the same as inversion, as the matrix is
        not orthogonal).  Therefore the co-ordinate change applied both forward and
        backward does not represent a "round trip".

        We know at this point that `grad` has more than one non-trivial dimension/axis
        (i.e. >1 dimension with size != 1).

        Args:
            p: the Parameter that the grad is for; may be needed if we are re-estimating
                the learning-rate matrix
            grad: the gradient to precondition (will not be modified by this function)
            state: the state-dict where we will look up the projections.
        Returns:
            The co-ordinate-changed grad or parameter change
        """
        cur_grad = grad
        ndim = grad.ndim
        step = state["step"]
        for dim in range(ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            proj = state[f"proj_{dim}"]

            if proj.ndim == 1:
                # We are treating this dimension diagonally because it is too big.
                proj = proj.reshape(size, *([1] * (ndim-1-dim)))
                cur_grad = cur_grad * proj
            else:
                cur_grad = self._multiply_on_dim(cur_grad, proj, dim, forward)
        return cur_grad

    def _multiply_on_dim(self, x: Tensor, M: Tensor, dim: int) -> Tensor:
        """
        Matrix-multiplies x by symmetric matrix `M` on x's dimension numbered `dim`.
        Args:
            x: Tensor to multiply, of arbitrary shape
            M: symmetric matrix to multiply x by, of shape
                (x.shape[dim], x.shape[dim]).
        """
        # Note: can use either M or M.t() here because M is symmetric
        return torch.matmul(x.transpose(-1, dim), M.t())

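Note that _multiply_on_dim returns its result with dimension dim moved to the last position, which is why callers follow it with transpose(-1, dim), as at the top of the large hunk above. A small sketch making the shape behaviour explicit; the local function merely mirrors the method above:

import torch

def multiply_on_dim(x: torch.Tensor, M: torch.Tensor, dim: int) -> torch.Tensor:
    # Same pattern as the method above; note the result has dimension `dim`
    # moved to the last position.
    return torch.matmul(x.transpose(-1, dim), M.t())

x = torch.randn(8, 16, 32)
M = torch.randn(16, 16)
M = 0.5 * (M + M.t())                    # symmetric, as the docstring assumes
y = multiply_on_dim(x, M, 1)
print(y.shape)                           # torch.Size([8, 32, 16])
print(y.transpose(-1, 1).shape)          # torch.Size([8, 16, 32]), after undoing the transpose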
    def _smooth_param_rms(self,
                          group: dict,
                          rms: Tensor,
@@ -836,176 +607,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        return rms / new_mean

    def _estimate_proj(self,
                       grad_cov: Tensor,
                       param_cov: Tensor,
                       param_pow: float,
                       param_rel_eps: float,
                       param_rel_max: float,
                       param_reverse_cutoff: float) -> Tensor:
        """
        Return a symmetric positive definite matrix P such that
             (P grad_cov P^T == param_cov^{param_pow}),
        i.e. a descent direction that makes the covariance of steps look
        like the covariance of the parameters.  This will be normalized so
        that the mean of its diagonal is 1.0 (since we'll have a separate
        such projection per dim of the tensor, and we don't want to count
        the scalar factor many times).

        Args:
            grad_cov: Covariance of gradients, a symmetric-positive-definite
                matrix [except for roundoff!] of shape (size, size)
            param_cov: Covariance of parameters, a symmetric-positive-definite
                matrix [except for roundoff!] of shape (size, size)
            param_pow: Power that we take "param_cov" to, between 0 and 1.
                Normally it will be 1, and not all the comments mention it.
        Returns:
            A matrix P such that (alpha P grad_cov P^T == param_cov^param_pow)
        """
        # Symmetrize grad_cov and param_cov.  The projection P only cares about the
        # relative sizes, so doubling both of them does not matter.
        G = grad_cov + grad_cov.t()
        C = param_cov + param_cov.t()
        U, S, V = C.svd()

        # Using G == grad_cov and C == param_cov, and not writing transpose because
        # all these quantities are symmetric, we can use:
        #
        #  P = C^{0.5} (C^{0.5} G C^{0.5})^{-0.5} C^{0.5}
        #
        # You can verify fairly easily that P G P == C: multiply on left and right
        # by C^{-0.5} on each side and you can see that both sides equal I.
        #
        # We can reduce the amount of roundoff by, after decomposing C
        # with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^0.5 U^T.
        #
        # Then we can write P as follows:
        #  P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q,
        # and we easily show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T, so we
        # can do this and cancel U^T U = U U^T = I, giving:
        #  P = Q^T (Q G Q^T)^{-0.5} Q.

        # Because C is symmetric, C == U S U^T, we can ignore V.
        # S_sqrt is S.sqrt() in the limit where param_pow == 1.0,
        # param_rel_eps=0, param_rel_max=inf.
        S_smoothed = self._smooth_param_diag_var(S, param_pow,
                                                 param_rel_eps, param_rel_max,
                                                 param_reverse_cutoff)
        S_sqrt = S_smoothed ** 0.5
        Q = (U * S_sqrt).t()

        X = torch.matmul(Q, torch.matmul(G, Q.t()))

        U2, S2, _ = X.svd()
        # Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
        # P = Q^T U2 S2^{-0.5} U2^T Q
        #   = Y Y^T, where
        # Y = Q^T U2 S2^{-0.25}

        S2_pow = (S2 + 1.0e-35) ** -0.25
        Y = torch.matmul(Q.t(), U2) * S2_pow

        P = torch.matmul(Y, Y.t())

        if random.random() < 1.0:  # 0.025:
            # TODO: remove this testing code.
            assert (P - P.t()).abs().mean() < 0.01  # make sure symmetric.
            try:
                P = 0.5 * (P + P.t())
                _, s, _ = P.svd()
                logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
            except:
                pass
            # Testing... note, this is only true modulo "eps".
            C_check = torch.matmul(torch.matmul(P, G), P)
            C_smoothed = torch.matmul(Q.t(), Q)
            # C_check should equal C_smoothed.
            C_diff = C_check - C_smoothed
            # Roundoff can cause significant differences, so use a fairly large
            # relative threshold of 0.01.  We may increase this later or even remove the check.
            if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
                logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")

        return P

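As a sanity check on the algebra in the comments above, here is a small numerical sketch of the same Q-factor construction with param_pow == 1 and the smoothing switched off, using random SPD stand-ins for grad_cov and param_cov; in that limit P G P^T should equal C up to roundoff:

import torch

torch.manual_seed(0)

def random_spd(n: int) -> torch.Tensor:
    A = torch.randn(n, n, dtype=torch.float64)
    return A @ A.t() + n * torch.eye(n, dtype=torch.float64)

n = 6
G, C = random_spd(n), random_spd(n)    # stand-ins for grad_cov and param_cov

U, S, _ = C.svd()
Q = (U * S.sqrt()).t()                 # C^{0.5} == U Q == Q^T U^T, and C == Q^T Q
X = Q @ G @ Q.t()
U2, S2, _ = X.svd()
Y = (Q.t() @ U2) * S2.pow(-0.25)       # P = Q^T X^{-0.5} Q = Y Y^T
P = Y @ Y.t()

print(torch.allclose(P @ G @ P, C, atol=1e-8))   # True: P G P^T == C here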
    def reset_speedup(self):
        """
        Reset the step_period so that we'll do a step each time, until the next
        time the speedup is recalibrated.
        """
        for group in self.param_groups:
            group["speedup_step"] = -group["speedup_first_step"]
            for p in group["params"]:
                state = self.state[p]
                state["step_period"] = 1

    def _recalibrate_speedup(self, group: dict):
        param_prods = []
        speedup = group["lr_for_speedup"] / group["lr"]
        if speedup > 10.0:
            speedup = 10.0
        if speedup <= 1.0:
            for p in group["params"]:
                state = self.state[p]
                state["step_period"] = 1
            return

        _, beta2, _ = group["betas"]

        for p in group["params"]:
            if p.grad is None:
                continue
            # We ignore the part of the parameter change that comes from the scale
            # (see 'scale_deriv' in the code)
            state = self.state[p]
            step = state["step"]

            # the sqrt() is just a heuristic, it seems to give periods that work better.
            param_prods.append(state["grad_times_step"].sqrt())

        param_prods = torch.stack(param_prods).to('cpu')
        # TEMP
        param_prods = param_prods.sqrt()
        avg_update_freq = 1.0 / speedup
        param_freqs = param_prods * (avg_update_freq / param_prods.mean())

        # Don't delay more than 1 / (3*speedup) steps for any parameter, to
        # ensure that no parameter waits super-long for an update.  Note: with
        # beta=0.1, this would mean an average delay of about 30*speedup steps
        # before the gradient makes it to the parameter.
        min_freq = 1.0 / (3 * speedup)

        # get the mean after applying limits of [min_freq..1] for the frequencies of
        # update
        param_freqs_mean = param_freqs.clamp(min=min_freq, max=1.0).mean()
        # renormalize after applying min and max limits
        param_freqs = param_freqs * (avg_update_freq / param_freqs_mean)
        # ...and apply the limits again.
        param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
        param_periods = (1.0 / param_freqs).to(torch.int32)  # .to(torch.int32) rounds down

        actual_speedup = 1.0 / (1.0 / param_periods).mean()

        param_prods = param_prods.tolist()
        param_freqs = param_freqs.tolist()
        param_periods = param_periods.tolist()

        logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
        print_info = random.random() < 0.01
        i = 0
        for p in group["params"]:
            if p.grad is None:
                continue
            state = self.state[p]
            state["step_period"] = param_periods[i]
            if print_info:
                logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
            i += 1

        # Do nothing more, as yet.

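To see how _recalibrate_speedup turns the per-parameter products into integer update periods, here is the same arithmetic run on made-up values (the products and speedup below are illustrative only):

import torch

speedup = 4.0                                        # illustrative value
param_prods = torch.tensor([0.1, 1.0, 4.0, 9.0]).sqrt()

avg_update_freq = 1.0 / speedup
param_freqs = param_prods * (avg_update_freq / param_prods.mean())
min_freq = 1.0 / (3 * speedup)

# Clamp to [min_freq, 1], renormalize so the mean frequency matches the target,
# clamp again, then round the periods down.
param_freqs_mean = param_freqs.clamp(min=min_freq, max=1.0).mean()
param_freqs = (param_freqs * (avg_update_freq / param_freqs_mean)).clamp(min=min_freq, max=1.0)
param_periods = (1.0 / param_freqs).to(torch.int32)

print(param_periods.tolist())                        # per-parameter update periods
print(1.0 / (1.0 / param_periods).mean())            # actual speedup, a bit below the requested 4.0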
def _svd(x: Tensor):