Some cleanup

This commit is contained in:
Daniel Povey 2022-07-24 04:22:11 +08:00
parent ddceb7963b
commit 33ffd17515

View File

@ -144,7 +144,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_lr_scale=0.1,
min_lr_factor=(0.05, 0.01, 0.01),
max_lr_factor=(10.0, 40.0, 10.0),
#param_pow=(0.99999, 0.99999, 0.99999),
param_pow=(1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
@ -154,7 +153,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
scalar_max=2.0,
size_update_period=4,
lr_update_period=(200, 1000),
grad_cov_period=3,
# grad_cov_period does not affect speed very much. If __name__ ==
# "__main__", keeping this coprime to 20 which is the number of
# sample pairs used in the test code
grad_cov_period=(3 if __name__ == "__main__" else 5),
max_block_size=1024,
):
@ -464,7 +466,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1-beta1))
@ -891,12 +892,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# A check.
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
if diff_ratio > 0.001:
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}")
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
if diff_ratio_l2 > 0.00001 or diff_ratio_l1 > 0.03:
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
# We just need the basis that diagonalizes this.
# We just need the basis U_z that diagonalizes this.
U_z, S, _ = _svd(Z)
if True:
skip = 10 if S.shape[-1] > 40 else 1
@ -1268,7 +1270,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
alpha = -lr * (1-beta1) / denom
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr*(1-beta1))