black formatted

JinZr 2023-10-31 19:55:29 +08:00
parent cab8c73c63
commit 6dd7fc46d2


@@ -437,7 +437,9 @@ class ScaledAdam(BatchedOptimizer):
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum() * (scalar_lr_scale ** 2)  # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
@@ -448,7 +450,9 @@ class ScaledAdam(BatchedOptimizer):
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        irregular_estimate_steps = [ i for i in [ 10, 20, 40 ] if i < clipping_update_period ]
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
         if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
@@ -459,9 +463,7 @@ class ScaledAdam(BatchedOptimizer):
             num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    num_norms - 1, (num_norms // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
@@ -499,8 +501,9 @@ class ScaledAdam(BatchedOptimizer):
             )
             if self.show_dominant_parameters:
                 assert p.shape[0] == len(param_names)
-                self._show_gradient_dominating_parameter(tuples, tot_sumsq,
-                                                          group["scalar_lr_scale"])
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
 
         if ans == 0.0:
             for (p, state, param_names) in tuples:
@@ -509,7 +512,9 @@ class ScaledAdam(BatchedOptimizer):
         return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor,
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
         scalar_lr_scale: float,
     ):
         """
@@ -531,11 +536,12 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.full(p.shape, scalar_lr_scale,
-                                            device=batch_grad.device)
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2)
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
             if batch_grad.ndim > 1:
                 # need to guard it with if-statement because sum() sums over
                 # all dims if dim == ().
@@ -679,8 +685,7 @@ class ScaledAdam(BatchedOptimizer):
             # We have to look at the trained model for parameters at or around the
             # param_max_rms, because sometimes they can indicate a problem with the
             # topology or settings.
-            scale_step = torch.minimum(scale_step,
-                                       (param_max_rms - param_rms) / param_rms)
+            scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
@@ -897,7 +902,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
@@ -946,7 +952,8 @@ class Eden2(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )