black formatted

This commit is contained in:
JinZr 2023-10-31 19:55:29 +08:00
parent cab8c73c63
commit 6dd7fc46d2

View File

@ -300,8 +300,8 @@ class ScaledAdam(BatchedOptimizer):
# the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups:
assert "named_params" in cur_group
name_list = [ x[0] for x in cur_group["named_params"] ]
p_list = [ x[1] for x in cur_group["named_params"] ]
name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list
param_groups.append(cur_group)
@ -437,7 +437,9 @@ class ScaledAdam(BatchedOptimizer):
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() * (scalar_lr_scale ** 2) # sum() to change shape [1] to []
tot_sumsq += (grad**2).sum() * (
scalar_lr_scale**2
) # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
@ -448,7 +450,9 @@ class ScaledAdam(BatchedOptimizer):
)
first_state["model_norms"][step % clipping_update_period] = tot_norm
irregular_estimate_steps = [ i for i in [ 10, 20, 40 ] if i < clipping_update_period ]
irregular_estimate_steps = [
i for i in [10, 20, 40] if i < clipping_update_period
]
if step % clipping_update_period == 0 or step in irregular_estimate_steps:
# Print some stats.
# We don't reach here if step == 0 because we would have returned
@ -459,9 +463,7 @@ class ScaledAdam(BatchedOptimizer):
num_norms = sorted_norms.numel()
quartiles = []
for n in range(0, 5):
index = min(
num_norms - 1, (num_norms // 4) * n
)
index = min(num_norms - 1, (num_norms // 4) * n)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
@ -499,18 +501,21 @@ class ScaledAdam(BatchedOptimizer):
)
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq,
group["scalar_lr_scale"])
self._show_gradient_dominating_parameter(
tuples, tot_sumsq, group["scalar_lr_scale"]
)
if ans == 0.0:
for (p, state, param_names) in tuples:
p.grad.zero_() # get rid of infinity()
p.grad.zero_() # get rid of infinity()
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor,
scalar_lr_scale: float,
self,
tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor,
scalar_lr_scale: float,
):
"""
Show information of parameter which dominates tot_sumsq.
@ -531,11 +536,12 @@ class ScaledAdam(BatchedOptimizer):
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
# Dummy values used by following `zip` statement.
batch_rms_orig = torch.full(p.shape, scalar_lr_scale,
device=batch_grad.device)
batch_rms_orig = torch.full(
p.shape, scalar_lr_scale, device=batch_grad.device
)
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2)
batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
if batch_grad.ndim > 1:
# need to guard it with if-statement because sum() sums over
# all dims if dim == ().
@ -679,8 +685,7 @@ class ScaledAdam(BatchedOptimizer):
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step,
(param_max_rms - param_rms) / param_rms)
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
@ -897,7 +902,8 @@ class Eden(LRScheduler):
warmup_factor = (
1.0
if self.batch >= self.warmup_batches
else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
else self.warmup_start
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
)
@ -946,7 +952,8 @@ class Eden2(LRScheduler):
warmup_factor = (
1.0
if self.batch >= self.warmup_batches
else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
else self.warmup_start
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
)