Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

Commit d84a2e22e3 (parent 2042c9862c): Applying max to G with noinv method with metric.
```diff
@@ -163,7 +163,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr=3e-02,
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
-            cov_min=(0.025, 0.0025, 0.02, 0.0001),
+            cov_min=(0.025, 0.0, 0.02, 0.0001),
             cov_max=(10.0, 80.0, 5.0, 400.0),
             cov_pow=(1.0, 1.0, 1.0, 1.0),
             param_rms_smooth0=0.4,
```
```diff
@@ -669,7 +669,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         if True:
             _, S, _ = _svd(P)
-            logging.info(f"Eigs of P are {S}")
+            logging.info(f"Eigs of P are {S[0][0]}")


         # C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
```
```diff
@@ -759,49 +759,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
             P_diag += smooth * P_diag_mean

-            G = G.clone()
-            G_diag = _diag(G)  # aliased
-            G_diag *= 1.005  # improve its condition, for numerical reasons.
+            #G = G.clone()
+            #G_diag = _diag(G)  # aliased
             G = self._smooth_cov(G,
                                  group["cov_min"][3],
                                  group["cov_max"][3],
                                  group["cov_pow"][3])

-            # C C^T == G.
-            C = G.cholesky()
-
-            P_orig = P.clone()
-
-            # treat the last dim of C as being in an arbitrary space, its next-to-last dim
-            # is the "canonical" one that we need to sum with the dims of P.
-            P_gnorm = torch.matmul(C.transpose(2, 3),
-                                   torch.matmul(P, C))
-
-            P_gnorm = self._smooth_cov(P_gnorm,
-                                       group["cov_min"][1],
-                                       group["cov_max"][1],
-                                       group["cov_pow"][1])
-
-            # torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
-            # In symbols, it solves AX = b, ... X = A^{-1} b,
-            # and assumes A is square upper-triangular (or lower-triangular if upper=False) and does not have zeros on the diagonal.
-            #
-            # We want to invert the normalization by C to recover P (if there weren't smoothing), so we want
-            #   P = torch.matmul(C_inv.transpose(2, 3),
-            #                    torch.matmul(P_gnorm, C_inv))
-
-            # the following is equivalent to: P_temp = torch.matmul(C_inv.transpose(2, 3), P_gnorm),
-            # where C_inv = C.inverse()
-            P_temp = torch.triangular_solve(P_gnorm, C, upper=False, transpose=True)[0]
-            # .. now we want to do the same on the other axis, need transpose.
-            P = torch.triangular_solve(P_temp.transpose(2, 3), C,
-                                       upper=False, transpose=True)[0].transpose(2, 3)
-
-            if True:  # TEMP
-                C_inv = C.inverse()
-                P2 = torch.matmul(C_inv.transpose(2, 3),
-                                  torch.matmul(P_gnorm, C_inv))
-                assert (P - P2).norm() <= 0.001 * P.norm()
+            P = self._apply_max_with_metric(P, G,
+                                            group["cov_max"][1])


             # Apply a 3rd round of smoothing in the canonical basis.
```
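For context on the path being removed: the comments above claim that two `torch.triangular_solve` calls against the Cholesky factor C undo the normalization of P without ever forming `C.inverse()` explicitly, and the temporary `if True:` check asserted the same thing. Below is a minimal standalone sketch of that identity, not taken from the repo; it uses small random SPD matrices and the newer `torch.linalg.cholesky` / `torch.linalg.solve_triangular` APIs in place of the deprecated calls in the diff.

```python
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)
P = A @ A.t() + n * torch.eye(n, dtype=torch.float64)   # SPD stand-in for P
B = torch.randn(n, n, dtype=torch.float64)
G = B @ B.t() + n * torch.eye(n, dtype=torch.float64)   # SPD stand-in for G
C = torch.linalg.cholesky(G)                             # lower-triangular, C C^T == G

P_gnorm = C.t() @ P @ C          # normalize P by the metric G (no smoothing here)

# Undo the normalization without forming C^{-1} explicitly:
# first solve C^T X = P_gnorm, i.e. X = C^{-T} P_gnorm ...
tmp = torch.linalg.solve_triangular(C.t(), P_gnorm, upper=True)
# ... then do the same on the other axis, via a transpose.
P_rec = torch.linalg.solve_triangular(C.t(), tmp.t(), upper=True).t()

# Reference computation with an explicit inverse, as in the removed "TEMP" check.
C_inv = torch.linalg.inv(C)
P_ref = C_inv.t() @ P_gnorm @ C_inv

assert torch.allclose(P_rec, P_ref, atol=1e-8)
assert torch.allclose(P_rec, P, atol=1e-8)   # the normalization is fully undone
```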
```diff
@@ -864,11 +831,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # have at this time.
         eig_ceil = X.shape[-1]

+        # the next statement slightly adjusts the target to be the same as
+        # what the baseline function, eig -> 1./(1./eig + 1./max_eig), would
+        # give, so "max_eig" is now the target of the function for arg ==
+        # eig_ceil.
+        max_eig = 1. / (1. / max_eig + 1. / eig_ceil)

         while max_eig <= eig_ceil * 0.5:
             # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
             # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
             # increasing from 0..eig_ceil.
             logging.info(f"X={X}, eig_ceil={eig_ceil}")
             X = X - 0.5/eig_ceil * torch.matmul(X, X)
             eig_ceil = 0.5 * eig_ceil

```
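The comments in this hunk describe the polynomial map `l -> l - 0.5/eig_ceil * l*l`, which halves the eigenvalue ceiling on each pass of the loop. A small illustrative check, not repo code and with made-up eigenvalues, that the matrix update `X = X - 0.5/eig_ceil * torch.matmul(X, X)` applies exactly this map to the eigenvalues of a symmetric X while leaving its eigenvectors unchanged:

```python
import torch

torch.manual_seed(0)
n = 6
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
eigs = torch.tensor([0.1, 0.5, 1.0, 2.0, 4.0, 6.0], dtype=torch.float64)  # all <= eig_ceil == n
X = Q @ torch.diag(eigs) @ Q.t()         # symmetric matrix with known eigenvalues

eig_ceil = float(n)
X2 = X - 0.5 / eig_ceil * (X @ X)        # the update from the loop above

# On eigenvalues this acts as l -> l - 0.5/eig_ceil * l*l; the eigenvectors are untouched.
expected = eigs - 0.5 / eig_ceil * eigs * eigs
got = torch.linalg.eigvalsh(X2)          # ascending order
assert torch.allclose(got, torch.sort(expected).values, atol=1e-8)
```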
```diff
@@ -889,6 +861,59 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return X


+    def _apply_max_with_metric(self,
+                               X: Tensor,
+                               M: Tensor,
+                               max_eig: float) -> Tensor:
+        """
+        Smooths X by applying a maximum to its eigenvalues (relative to the mean),
+        measured with respect to the metric M.  Equivalent to applying:
+           Y := M^{0.5} X M^{0.5}
+           Apply maximum to eigenvalues of Y, as done in _smooth_cov()
+           X := M^{-0.5} Y M^{-0.5}
+
+        Args:
+          X: Batch of block-diagonal nonzero positive definite matrices to have the max eigenvalue applied
+          M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
+        """
+        X = X.clone()
+        # mean eig of M^{0.5} X M^{0.5} ...
+        mean_eig = (X*M).sum(dim=(2,3), keepdim=True) / X.shape[-1]
+        # make sure the eigs of M^{0.5} X M^{0.5} average 1; this imposes a limit on the max.
+        X /= mean_eig
+
+        eig_ceil = X.shape[-1]
+
+        # the next statement slightly adjusts the target to be the same as
+        # what the baseline function, eig -> 1./(1./eig + 1./max_eig), would
+        # give, so "max_eig" is now the target of the function for arg ==
+        # eig_ceil.
+        max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
+
+        while max_eig <= eig_ceil * 0.5:
+            # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
+            # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
+            # increasing from 0..eig_ceil.
+            #logging.info(f"X={X}, eig_ceil={eig_ceil}")
+            X = X - 0.5/eig_ceil * torch.matmul(X, torch.matmul(M, X))
+            eig_ceil = 0.5 * eig_ceil
+
+        # max_eig > eig_ceil * 0.5
+        if max_eig < eig_ceil:
+            # map l -> l - coeff*l*l; if l == eig_ceil, this takes us to:
+            #   eig_ceil - ((eig_ceil - max_eig) / (eig_ceil*eig_ceil)) * eig_ceil * eig_ceil
+            #   == max_eig
+            # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
+            # means that the function is monotonic on inputs from 0 to eig_ceil.
+            coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
+            X = X - coeff * torch.matmul(X, torch.matmul(M, X))
+
+        # Normalize again.
+        X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
+        X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
+        return X
+

     def _update_grad_cov(self,
                          group: dict,
```
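Per its docstring, the new `_apply_max_with_metric` caps the eigenvalues of `Y = M^{0.5} X M^{0.5}` without ever forming `M^{0.5}` or its inverse; the loop works directly with `X - coeff * X M X`. The sketch below is illustrative only (random SPD matrices, a hand-picked `coeff`, plain 2-D tensors rather than the optimizer's batched layout); it checks that this metric-space update is exactly the plain map `Y -> Y - coeff * Y Y` carried back through `M^{0.5}`.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)
X = A @ A.t() + torch.eye(n, dtype=torch.float64)   # SPD stand-in for X
B = torch.randn(n, n, dtype=torch.float64)
M = B @ B.t() + torch.eye(n, dtype=torch.float64)   # SPD stand-in for the metric M

# Symmetric square root of M via its eigendecomposition.
evals, evecs = torch.linalg.eigh(M)
M_half = evecs @ torch.diag(evals.sqrt()) @ evecs.t()

coeff = 0.05
Y = M_half @ X @ M_half

# Apply the metric-space update to X, then view the result in the "Y" space ...
lhs = M_half @ (X - coeff * X @ M @ X) @ M_half
# ... and compare with applying the plain eigenvalue map directly to Y.
rhs = Y - coeff * Y @ Y

assert torch.allclose(lhs, rhs, atol=1e-8)
```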
```diff
@@ -1827,7 +1852,7 @@ def _test_eve_cain():
     fix_random_seed(42)
     Linear = torch.nn.Linear if iter == 0 else ScaledLinear

-    hidden_dim = 300
+    hidden_dim = 200
     m = torch.nn.Sequential(Linear(E, hidden_dim),
                             torch.nn.PReLU(),
                             Linear(hidden_dim, hidden_dim),
```