Some cleanup
This commit is contained in:
parent
ddceb7963b
commit
33ffd17515
@ -144,7 +144,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
size_lr_scale=0.1,
|
size_lr_scale=0.1,
|
||||||
min_lr_factor=(0.05, 0.01, 0.01),
|
min_lr_factor=(0.05, 0.01, 0.01),
|
||||||
max_lr_factor=(10.0, 40.0, 10.0),
|
max_lr_factor=(10.0, 40.0, 10.0),
|
||||||
#param_pow=(0.99999, 0.99999, 0.99999),
|
|
||||||
param_pow=(1.0, 1.0, 1.0),
|
param_pow=(1.0, 1.0, 1.0),
|
||||||
param_rms_smooth0=0.4,
|
param_rms_smooth0=0.4,
|
||||||
param_rms_smooth1=0.2,
|
param_rms_smooth1=0.2,
|
||||||
@ -154,7 +153,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
scalar_max=2.0,
|
scalar_max=2.0,
|
||||||
size_update_period=4,
|
size_update_period=4,
|
||||||
lr_update_period=(200, 1000),
|
lr_update_period=(200, 1000),
|
||||||
grad_cov_period=3,
|
# grad_cov_period does not affect speed very much. If __name__ ==
|
||||||
|
# "__main__", keeping this coprime to 20 which is the number of
|
||||||
|
# sample pairs used in the test code
|
||||||
|
grad_cov_period=(3 if __name__ == "__main__" else 5),
|
||||||
max_block_size=1024,
|
max_block_size=1024,
|
||||||
):
|
):
|
||||||
|
|
||||||
@ -464,7 +466,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
|
|
||||||
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
|
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
|
||||||
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
# the factor of (1-beta1) relates to momentum.
|
# the factor of (1-beta1) relates to momentum.
|
||||||
delta.add_(p * scale_step, alpha=(1-beta1))
|
delta.add_(p * scale_step, alpha=(1-beta1))
|
||||||
@ -891,12 +892,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
# A check.
|
# A check.
|
||||||
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
|
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
|
||||||
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
|
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
|
||||||
diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
|
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
|
||||||
if diff_ratio > 0.001:
|
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
|
||||||
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}")
|
if diff_ratio_l2 > 0.00001 or diff_ratio_l1 > 0.03:
|
||||||
|
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
|
||||||
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
|
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
|
||||||
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
|
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
|
||||||
# We just need the basis that diagonalizes this.
|
# We just need the basis U_z that diagonalizes this.
|
||||||
U_z, S, _ = _svd(Z)
|
U_z, S, _ = _svd(Z)
|
||||||
if True:
|
if True:
|
||||||
skip = 10 if S.shape[-1] > 40 else 1
|
skip = 10 if S.shape[-1] > 40 else 1
|
||||||
@ -1268,7 +1270,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
# slower update at the start will help stability anyway.
|
# slower update at the start will help stability anyway.
|
||||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||||
alpha = -lr * (1-beta1) / denom
|
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
delta.add_(grad / denom, alpha=-lr*(1-beta1))
|
delta.add_(grad / denom, alpha=-lr*(1-beta1))
|
||||||
|
|||||||
Reference in New Issue
Block a user