diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index a0d9eacc9..65cfeedc3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -144,7 +144,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of size_lr_scale=0.1, min_lr_factor=(0.05, 0.01, 0.01), max_lr_factor=(10.0, 40.0, 10.0), - #param_pow=(0.99999, 0.99999, 0.99999), param_pow=(1.0, 1.0, 1.0), param_rms_smooth0=0.4, param_rms_smooth1=0.2, @@ -154,7 +153,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of scalar_max=2.0, size_update_period=4, lr_update_period=(200, 1000), - grad_cov_period=3, + # grad_cov_period does not affect speed very much. If __name__ == + # "__main__", keeping this coprime to 20 which is the number of + # sample pairs used in the test code + grad_cov_period=(3 if __name__ == "__main__" else 5), max_block_size=1024, ): @@ -464,7 +466,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of scale_step.masked_fill_(is_too_small, size_lr * size_update_period) scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) - delta = state["delta"] # the factor of (1-beta1) relates to momentum. delta.add_(p * scale_step, alpha=(1-beta1)) @@ -891,12 +892,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # A check. # Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2) P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime) - diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt() - if diff_ratio > 0.001: - logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}") + diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt() + diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum() + if diff_ratio_l2 > 0.00001 or diff_ratio_l1 > 0.03: + logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}") Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3))) # OK, Z is the SPD transform that maps G to P, as in Z G Z = P. - # We just need the basis that diagonalizes this. + # We just need the basis U_z that diagonalizes this. U_z, S, _ = _svd(Z) if True: skip = 10 if S.shape[-1] > 40 else 1 @@ -1268,7 +1270,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # slower update at the start will help stability anyway. bias_correction2 = 1 - beta2 ** (state["step"] + 1) denom = (exp_avg_sq / bias_correction2).sqrt() + eps - alpha = -lr * (1-beta1) / denom delta = state["delta"] delta.add_(grad / denom, alpha=-lr*(1-beta1))