Applying max to G with noinv method, with metric.

Daniel Povey 2022-07-31 02:10:27 -07:00
parent 2042c9862c
commit d84a2e22e3


@@ -163,7 +163,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
cov_min=(0.025, 0.0025, 0.02, 0.0001),
cov_min=(0.025, 0.0, 0.02, 0.0001),
cov_max=(10.0, 80.0, 5.0, 400.0),
cov_pow=(1.0, 1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
@@ -669,7 +669,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
if True:
_,S,_ = _svd(P)
logging.info(f"Eigs of P are {S}")
logging.info(f"Eigs of P are {S[0][0]}")
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
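As an aside, the factorization property stated in the comment above can be checked in isolation. A minimal standalone sketch (not part of this diff), using torch.linalg.cholesky on a batch of random positive definite matrices:

import torch

# A small batch of random symmetric positive definite matrices P.
A = torch.randn(2, 3, 4, 4, dtype=torch.float64)
P = torch.matmul(A, A.transpose(2, 3)) + 0.1 * torch.eye(4, dtype=torch.float64)

# C is lower-triangular and satisfies P == C C^T, as the comment states.
C = torch.linalg.cholesky(P)
assert torch.allclose(torch.matmul(C, C.transpose(2, 3)), P)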
@@ -759,49 +759,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
P_diag += smooth * P_diag_mean
G = G.clone()
G_diag = _diag(G) # aliased
G_diag *= 1.005 # improve its condition, for numerical reasons.
#G = G.clone()
#G_diag = _diag(G) # aliased
G = self._smooth_cov(G,
group["cov_min"][3],
group["cov_max"][3],
group["cov_pow"][3])
# C C^T == G.
C = G.cholesky()
P_orig = P.clone()
# treat the last dim of C as being in an arbitrary space, its next-to-last dim
# is the "canonical" one that we need to sum with the dims of P.
P_gnorm = torch.matmul(C.transpose(2, 3),
torch.matmul(P, C))
P_gnorm = self._smooth_cov(P_gnorm,
group["cov_min"][1],
group["cov_max"][1],
group["cov_pow"][1])
#torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
#In symbols, it solves AX = b, ... X = A^{-1} b
#and assumes A is square upper-triangular (or lower-triangular if upper= False) and does not have zeros on the diagonal.
#
# We want to invert the normalization by C to recover P (if there weren't smoothing), so we want
# P = torch.matmul(C_inv.transpose(2, 3),
# torch.matmul(P_gnorm, C_inv))
# the following is equivalent to: P_temp = torch.matmul(C_inv.transpose(2, 3), P_gnorm),
# where C_inv = C.invert()
P_temp = torch.triangular_solve(P_gnorm, C, upper=False, transpose=True)[0]
# .. now we want to do the same on the other axis, need transpose.
P = torch.triangular_solve(P_temp.transpose(2, 3), C,
upper=False, transpose=True)[0].transpose(2, 3)
if True: # TEMP
C_inv = C.inverse()
P2 = torch.matmul(C_inv.transpose(2, 3),
torch.matmul(P_gnorm, C_inv))
assert (P-P2).norm() <= 0.001 * P.norm()
P = self._apply_max_with_metric(P, G,
group["cov_max"][1])
# Apply a 3rd round of smoothing in the canonical basis.
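For reference, the removed block normalizes P by C via P_gnorm = C^T P C and then undoes that with two triangular solves instead of forming C^{-1} explicitly. A standalone sketch of that round trip, using torch.linalg.solve_triangular (the current replacement for the deprecated torch.triangular_solve); the make_pd helper is illustrative, not from this file:

import torch

def make_pd(*shape):
    # illustrative helper: batch of random symmetric positive definite matrices
    A = torch.randn(*shape, dtype=torch.float64)
    return torch.matmul(A, A.transpose(2, 3)) + 0.1 * torch.eye(shape[-1], dtype=torch.float64)

P = make_pd(2, 3, 4, 4)
G = make_pd(2, 3, 4, 4)
C = torch.linalg.cholesky(G)  # C C^T == G, C lower-triangular

# Normalize P by C, as in the removed code above.
P_gnorm = torch.matmul(C.transpose(2, 3), torch.matmul(P, C))

# Undo the normalization, P == C^{-T} P_gnorm C^{-1}, via triangular solves:
tmp = torch.linalg.solve_triangular(C.transpose(2, 3), P_gnorm, upper=True)  # C^{-T} P_gnorm
P2 = torch.linalg.solve_triangular(C, tmp, upper=False, left=False)          # ... times C^{-1}
assert (P - P2).norm() <= 1e-6 * P.norm()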
@@ -864,11 +831,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# have at this time.
eig_ceil = X.shape[-1]
# the next statement slightly adjusts the target so that it matches what
# the baseline function, eig -> 1./(1./eig + 1./max_eig), would give;
# so "max_eig" is now the target of the function for arg == eig_ceil.
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
while max_eig <= eig_ceil * 0.5:
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
# on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
# increasing on 0..eig_ceil.
logging.info(f"X={X}, eig_ceil={eig_ceil}")
X = X - 0.5/eig_ceil * torch.matmul(X, X)
eig_ceil = 0.5 * eig_ceil
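The comment above asserts that X -> X - 0.5/eig_ceil * X @ X acts on each eigenvalue as l -> l - 0.5/eig_ceil * l*l, halving the ceiling per iteration; this holds because a polynomial in X shares X's eigenvectors. A minimal standalone check of that correspondence (values are arbitrary):

import torch

c = 8.0  # stand-in for eig_ceil
eigs = torch.tensor([0.5, 1.0, 3.0, 6.0, 8.0], dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(5, 5, dtype=torch.float64))
X = Q @ torch.diag(eigs) @ Q.T  # symmetric matrix with eigenvalues `eigs`

X2 = X - 0.5 / c * (X @ X)         # one step of the iteration above
mapped = eigs - 0.5 / c * eigs**2  # the same map applied to the eigenvalues directly

# eigvalsh returns eigenvalues in ascending order; the map is monotonic on [0, c],
# so `mapped` is already sorted.  Note 8.0 (== c) maps to 4.0 (== 0.5*c).
assert torch.allclose(torch.linalg.eigvalsh(X2), mapped)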
@@ -889,6 +861,59 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return X
def _apply_max_with_metric(self,
X: Tensor,
M: Tensor,
max_eig: float) -> Tensor:
"""
Smooths X with maximum eigenvalue (relative to the mean) relative to
metric M. Equivalent to applying
Y := M^{0.5} X M^{0.5}
Apply maximum to eigenvalues of Y, as done in _smooth_cov()
X := M^{-0.5} Y M^{-0.5}
Args:
X: Batch of block-diagonal nonzero positive definite matrices to have max eigenvalue applied
M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
"""
X = X.clone()
# mean eig of M^{0.5} X M^{0.5} ...
mean_eig = (X*M).sum(dim=(2,3), keepdim=True) / X.shape[-1]
# make sure eigs of M^{0.5} X M^{0.5} are average 1. this imposes limit on the max.
X /= mean_eig
eig_ceil = X.shape[-1]
# the next statement slightly adjusts the target so that it matches what
# the baseline function, eig -> 1./(1./eig + 1./max_eig), would give;
# so "max_eig" is now the target of the function for arg == eig_ceil.
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
while max_eig <= eig_ceil * 0.5:
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l on the
# eigenvalues of M^{0.5} X M^{0.5}, which maps eig_ceil to 0.5*eig_ceil
# and is monotonically increasing on 0..eig_ceil.
#logging.info(f"X={X}, eig_ceil={eig_ceil}")
X = X - 0.5/eig_ceil * torch.matmul(X, torch.matmul(M, X))
eig_ceil = 0.5 * eig_ceil
# max_eig > eig_ceil * 0.5
if max_eig < eig_ceil:
# map l -> l - coeff*l*l; if l == eig_ceil, this
# takes us to:
#   eig_ceil - ((eig_ceil - max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
#   == max_eig
# .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
# means that the function is monotonic on inputs from 0 to eig_ceil.
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
X = X - coeff * torch.matmul(X, torch.matmul(M, X))
# Normalize again.
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
return X
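A standalone sanity check of the identity this function builds on: the update X -> X - a * X @ M @ X is the plain quadratic map Y -> Y - a * Y @ Y conjugated through Y = M^{0.5} X M^{0.5}. A hedged sketch (the sqrtm_psd helper is illustrative, not from this file):

import torch

def sqrtm_psd(M):
    # illustrative helper: matrix square root of a symmetric positive definite matrix
    eigs, Q = torch.linalg.eigh(M)
    return Q @ torch.diag(eigs.sqrt()) @ Q.T

A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 4, dtype=torch.float64)
X = A @ A.T + 0.1 * torch.eye(4, dtype=torch.float64)
M = B @ B.T + 0.1 * torch.eye(4, dtype=torch.float64)
M_half = sqrtm_psd(M)

a = 0.05
Y = M_half @ X @ M_half
lhs = M_half @ (X - a * X @ M @ X) @ M_half  # the metric update, viewed in Y-space
rhs = Y - a * Y @ Y                          # the plain quadratic map on Y
assert torch.allclose(lhs, rhs)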
def _update_grad_cov(self,
group: dict,
@@ -1827,7 +1852,7 @@ def _test_eve_cain():
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
hidden_dim = 300
hidden_dim = 200
m = torch.nn.Sequential(Linear(E, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim),