mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00

commit 4da4e69fba (parent a63afe348a)
Draft of new way of smoothing param_rms, diagonalized by grad
@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr=3e-02,
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
-            param_pow=0.75,
+            param_pow=1.0,
             param_rms_smooth0=0.75,
             param_rms_smooth1=0.25,
-            max_lr_factor=4.0,
+            max_lr_factor=10.0,
             eps=1.0e-08,
             param_min_rms=1.0e-05,
             param_max_rms=2.0,
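The two param_rms_smooth* defaults above parameterize the rank-dependent
smoothing proportion used by the new _get_smoothed_param_cov method in the
last hunk below. A quick standalone check of that formula at these defaults
(the size and rank values here are made up for illustration):

    # smooth = smooth0 at rank == 0, decaying to smooth1 at rank == size
    smooth0, smooth1 = 0.75, 0.25

    def smooth(rank: int, size: int) -> float:
        return smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)

    size = 512
    for rank in (0, 64, 512):
        print(rank, smooth(rank, size))  # -> 0.75, 0.6, 0.25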
@@ -593,10 +593,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size = p.shape[dim]
             try:
                 Q = state[f"Q_{dim}"]
-                param_cov = state[f"param_cov_{dim}"]
             except KeyError:
                 assert size == 1 or size == numel, size
-                continue # e.g. size == 1 or size == numel:
+                continue # e.g. size == 1 or size == numel
 
+            param_cov = self._get_smoothed_param_cov(group, p, state, dim)
+
             # param_cov has the same shape as Q
             (batch_size, num_blocks, block_size, block_size) = Q.shape
@@ -683,7 +684,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # rms shape: (batch_size, 1, size, 1, 1)
             rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
             rank = numel // size
-            smoothed_rms = self._smooth_param_rms(group, rms, rank)
+            # we did other kinds of smoothing in _get_smoothed_param_cov
+            #smoothed_rms = self._smooth_param_rms(group, rms, rank)
+            smoothed_rms = rms ** group["param_pow"]
             cur_scales[dim] = smoothed_rms
             cur_p /= smoothed_rms  # normalize/"whiten" cur_p on this dim..
 
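With the new default param_pow=1.0 from the first hunk, dividing by
smoothed_rms = rms ** param_pow whitens cur_p fully on this dim, whereas the
old 0.75 only partially flattened the rms spread. A minimal sketch with
hypothetical tensor values:

    import torch

    rms = torch.tensor([0.1, 1.0, 4.0])   # hypothetical per-dim rms values
    for param_pow in (0.75, 1.0):
        smoothed_rms = rms ** param_pow
        print(param_pow, rms / smoothed_rms)
    # param_pow=0.75 leaves a residual spread of rms**0.25;
    # param_pow=1.0 normalizes every dim to rms 1.0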
@@ -727,6 +730,95 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             state[f"Q_{dim}"] *= scale
             state["last_param_scale_update"] = state["step"]
 
+
+    def _get_smoothed_param_cov(self,
+                                group: dict,
+                                p: Tensor,
+                                state: dict,
+                                dim: int) -> Tensor:
+        """
+        This function returns a modified/smoothed version of the parameter covariance
+        state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
+        p, averaged over time, and taken over dimension `dim` of the tensor.
+
+        The smoothing done here limits the extent to which the parameter covariance
+        can be strongly "off-diagonal" with respect to the gradient covariance.  That is:
+        if the parameter covariance is just the gradient covariance to some power, this
+        function does no smoothing; but if it is highly off-diagonal we do more smoothing.
+        """
+        param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+        grad_cov = state[f"grad_cov_{dim}"]    # (batch_size, num_blocks, block_size, block_size)
+        (batch_size, num_blocks, block_size, block_size) = param_cov.shape
+        U_g, _, _ = _svd(grad_cov)  # U_g diagonalizes grad_cov, in the sense that
+                                    # U_g^T grad_cov U_g is diagonal.
+
+        # param_cov_proj is param_cov in a different orthonormal basis, one that
+        # diagonalizes grad_cov.
+        param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
+
+        # param_cov_eps is probably not critical; I don't expect to see super-small
+        # values.  Apply it as a floor in case roundoff causes negative values.
+        param_cov_eps = 1.0e-05
+        param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
+        param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
+
+        # param_cov_norm should have diagonal values close to 1.0 (not exactly 1.0
+        # only because of param_cov_eps and roundoff).
+        param_cov_norm = param_cov_proj / param_cov_inv_scale
+
+        # OK, this is where we do the smoothing.
+        # Decompose param_cov_norm, which is symmetric, as U_p S U_p^T.
+        U_p, S, _ = _svd(param_cov_norm)
+
+        residual_rms = S.sqrt()
+
+        relative_rms_pow = 0.7
+        relative_rms_max = 4.0
+
+        residual_rms = residual_rms ** relative_rms_pow
+        residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
+
+        if True:
+            # Smooth according to the rank of the observation.
+            size = p.shape[dim]
+            rank = p.numel() // (size * batch_size)
+            smooth0 = group["param_rms_smooth0"]
+            smooth1 = group["param_rms_smooth1"]
+            # We want an expression of the form: smooth = alpha * size / (beta*rank + size).
+            # From rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
+            # From setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
+            # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1.
+            smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
+
+            mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
+            residual_rms += group["eps"] + smooth * mean
+            residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)
+
+        # Apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y).
+        residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)
+
+        if random.random() < 0.1:
+            skip = 10 if S.shape[-1] > 40 else 1
+            logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to "
+                         f"{residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")
+
+        # U_p shape: (batch_size, num_blocks, block_size, block_size),
+        # residual_rms shape: (batch_size, num_blocks, block_size),
+        # so in terms of matrix multiplication, we are computing
+        # X_p = matmul(U_p, residual_rms.diag()).
+        X_p = U_p * residual_rms.unsqueeze(-2)
+        param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))
+
+        # Undo the scaling by the diagonal of param_cov.
+        param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale
+
+        # Undo the projection by U_g.
+        param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
+                                                            U_g.transpose(2, 3)))
+        return param_cov_smoothed
+
+
     def _diagonalize_grad_cov(self,
                               group: dict,
                               p: Tensor,
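On the key step in the new method: for a symmetric positive semi-definite
matrix such as grad_cov, the U returned by an SVD satisfies
U^T grad_cov U = diag(S). A standalone sketch using torch.linalg.svd (the
in-repo _svd helper is presumably a batched wrapper around something similar;
that is an assumption here):

    import torch

    torch.manual_seed(0)
    A = torch.randn(6, 6)
    grad_cov = A @ A.T                      # symmetric PSD, like grad_cov
    U_g, S, _ = torch.linalg.svd(grad_cov)  # for symmetric PSD, SVD == eigendecomposition
    proj = U_g.T @ grad_cov @ U_g
    off_diag = proj - torch.diag(proj.diagonal())
    print(off_diag.abs().max())             # close to zero: U_g^T grad_cov U_g is diagonal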
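The division by param_cov_inv_scale turns param_cov_proj into a
correlation-like matrix with unit diagonal, so the smoothing acts only on the
off-diagonal structure. A standalone (non-batched) sketch of that step:

    import torch

    torch.manual_seed(0)
    A = torch.randn(6, 6)
    param_cov_proj = A @ A.T
    param_rms = param_cov_proj.diagonal().clamp(min=1.0e-5).sqrt()
    inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
    param_cov_norm = param_cov_proj / inv_scale
    print(param_cov_norm.diagonal())        # all ~1.0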
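The softmin cap is a smooth alternative to clamp(max=relative_rms_max):
values well below the cap pass through nearly unchanged, while large values
saturate at 4.0. For example:

    import torch

    relative_rms_max = 4.0
    residual_rms = torch.tensor([0.1, 1.0, 4.0, 100.0])
    capped = 1. / (1. / residual_rms + 1. / relative_rms_max)
    print(capped)  # tensor([0.0976, 0.8000, 2.0000, 3.8462]) -- never exceeds 4.0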
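Finally, the broadcasted multiply X_p = U_p * residual_rms.unsqueeze(-2) is
the batched equivalent of right-multiplying U_p by diag(residual_rms), so
X_p @ X_p^T equals U_p diag(residual_rms**2) U_p^T, i.e. param_cov_norm with
its eigenvalue spectrum replaced by the smoothed one. A quick equivalence
check with made-up shapes:

    import torch

    torch.manual_seed(0)
    U_p = torch.randn(2, 3, 4, 4)             # (batch_size, num_blocks, block_size, block_size)
    residual_rms = torch.rand(2, 3, 4) + 0.5  # (batch_size, num_blocks, block_size)
    X_p = U_p * residual_rms.unsqueeze(-2)    # scales column j of U_p by residual_rms[..., j]
    ref = torch.matmul(U_p, torch.diag_embed(residual_rms))
    print(torch.allclose(X_p, ref))           # True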