Add denom_rel_eps, and set it to 1e-05

This commit is contained in:
Daniel Povey 2022-07-28 09:10:20 +08:00
parent dc565f729b
commit 8654a7385d

View File

@ -129,7 +129,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
the gradient covariance matrix used in stage (2) of smoothing. If the
1st 3 values are not 1.0 it will cause a measurable speed penalty because it
requires SVD. Recommend to leave all these at 1.0.
eps: A general-purpose epsilon to prevent division by zero
eps: A general-purpose epsilon to prevent division by zero
denom_rel_eps: An epsilon value used to keep the elements of the denominator (exp_avg_sq.sqrt())
not too small relative to the mean value; this will limit the speed of update
along directions with very small gradients.
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
@ -165,6 +168,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
eps=1.0e-08,
denom_rel_eps=1.0e-05,
param_min_rms=1.0e-05,
param_max_rms=2.0,
scalar_max=2.0,
@ -189,6 +193,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param_rms_smooth1=param_rms_smooth1,
betas=betas,
eps=eps,
denom_rel_eps=denom_rel_eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
scalar_max=scalar_max,
@ -972,21 +977,22 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate]
Q[:] = U_z.transpose(2, 3)
# Work out the projected smoothed parameter covariance P_proj, which is P
# projected in the basis U_z. Now,
# Work out the diagonal P_proj_diag of the projected smoothed parameter covariance P_proj, which is P
# projected in the basis U_z. This will be used to get the parameter scales for the
# bases Q_{dim}. Now,
# P = U_g P' U_g^T,
# and P_proj = U_z^T P U_z,
# so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T
U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
# this_P_proj shape: (batch_size, num_blocks, block_size)
this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
this_P_proj_diag = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
P_proj[dim] = this_P_proj_diag.clone().reshape(batch_size, size)
simple_update = True
if simple_update:
# normalize the scales in a way that preserves the Frobenius norm of the
# projected parameter deltas
P_rms = this_P_proj / _mean(this_P_proj, exclude_dims=[0], keepdim=True)
P_rms = this_P_proj_diag / _mean(this_P_proj_diag, exclude_dims=[0], keepdim=True)
scale = P_rms.unsqueeze(-1).sqrt()
Q *= scale
logging.info(f"Q rms = {(Q**2).mean().sqrt()} abs-rms = {Q.abs().mean()}")
@ -1003,7 +1009,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
U_prod.transpose(2, 3))))
G_proj_unsmoothed = _diag(torch.matmul(U_prod * G_prime.unsqueeze(-2), U_prod.transpose(2, 3)))
skip = 10 if block_size > 40 else 1
logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,0:block_size:skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,0,::skip]}, diag of unsmoothed G_proj is {G_proj_unsmoothed[0,0,::skip]}")
logging.info(f"dim={dim}, diag of P_proj is: {this_P_proj_diag[0,0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,0,::skip]}, diag of unsmoothed G_proj is {G_proj_unsmoothed[0,0,::skip]}")
if size == 500:
xyz
# P_proj won't be needed if simple_update == True.
return P_proj
@ -1315,6 +1323,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
denom_rel_eps = group["denom_rel_eps"]
step = state["step"]
grad = self._project(grad, state, forward=True)
@ -1328,7 +1337,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt() + eps
denom = exp_avg_sq.sqrt()
denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True)
grad = grad / denom