mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 23:54:17 +00:00)

black formatted

This commit is contained in:
  parent cab8c73c63
  commit 6dd7fc46d2
@@ -300,8 +300,8 @@ class ScaledAdam(BatchedOptimizer):
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
                 assert "named_params" in cur_group
-                name_list = [ x[0] for x in cur_group["named_params"] ]
-                p_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
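For orientation, a minimal runnable sketch of the data shape this loop expects: a group dict whose "named_params" entry pairs parameter names with tensors, which the code above splits into name_list / p_list. The toy Linear model and the lr value are illustrative, not taken from this commit.

import torch

# A toy model standing in for a real icefall network.
model = torch.nn.Linear(4, 2)

# One group whose "named_params" entry pairs each parameter name with its tensor,
# which is what the loop above unpacks into name_list / p_list.
cur_group = {
    "named_params": list(model.named_parameters()),
    "lr": 0.045,  # illustrative value, not taken from this commit
}

name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list

print(name_list)        # ['weight', 'bias']
print(list(cur_group))  # ['lr', 'params']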
@@ -437,7 +437,9 @@ class ScaledAdam(BatchedOptimizer):
                         "ScaledAdam optimizer does not support sparse gradients"
                     )
                 if p.numel() == p.shape[0]:  # a batch of scalars
-                    tot_sumsq += (grad**2).sum() * (scalar_lr_scale ** 2)  # sum() to change shape [1] to []
+                    tot_sumsq += (grad**2).sum() * (
+                        scalar_lr_scale**2
+                    )  # sum() to change shape [1] to []
                 else:
                     tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
 
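A small self-contained sketch of the accumulation the reformatted lines perform: gradients of scalar batches are weighted by scalar_lr_scale**2, other parameters by their stored RMS. The tensors and the 0.1 scale below are made-up stand-ins.

import torch

scalar_lr_scale = 0.1        # stand-in value
tot_sumsq = torch.zeros(())  # scalar accumulator, shape []

# A "batch of scalars" (numel() == shape[0]): weighted by scalar_lr_scale**2.
scalar_grad = torch.tensor([0.5, -0.25])
tot_sumsq += (scalar_grad**2).sum() * (scalar_lr_scale**2)

# A regular batched parameter: its gradient is weighted by the stored RMS
# (a stand-in for state["param_rms"]).
grad = torch.randn(3, 4)
param_rms = torch.rand(3, 1) + 0.1
tot_sumsq += ((grad * param_rms) ** 2).sum()

print(tot_sumsq.shape, float(tot_sumsq))  # torch.Size([]) <some positive number>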
@@ -448,7 +450,9 @@ class ScaledAdam(BatchedOptimizer):
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        irregular_estimate_steps = [ i for i in [ 10, 20, 40 ] if i < clipping_update_period ]
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
         if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
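For reference, a tiny sketch of when the statistics branch above fires: every clipping_update_period steps, plus a few early "irregular" steps. The period of 1000 is illustrative.

clipping_update_period = 1000  # illustrative period
irregular_estimate_steps = [i for i in [10, 20, 40] if i < clipping_update_period]

for step in (10, 40, 500, 1000, 2000):
    if step % clipping_update_period == 0 or step in irregular_estimate_steps:
        print(f"clipping stats would be reported at step {step}")
# prints for steps 10, 40, 1000 and 2000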
@@ -459,9 +463,7 @@ class ScaledAdam(BatchedOptimizer):
             num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    num_norms - 1, (num_norms // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
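A hedged sketch of the quartile summary the loop above computes over a window of recorded gradient norms; the random values stand in for first_state["model_norms"].

import torch

model_norms = torch.rand(128) * 5.0  # stand-in for first_state["model_norms"]
sorted_norms = model_norms.sort()[0]
num_norms = sorted_norms.numel()

quartiles = []
for n in range(0, 5):
    index = min(num_norms - 1, (num_norms // 4) * n)
    quartiles.append(sorted_norms[index].item())

median = quartiles[2]
print("min/25%/median/75%/max of recent grad norms:", quartiles)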
@@ -499,18 +501,21 @@ class ScaledAdam(BatchedOptimizer):
                 )
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(tuples, tot_sumsq,
-                                                              group["scalar_lr_scale"])
+                    self._show_gradient_dominating_parameter(
+                        tuples, tot_sumsq, group["scalar_lr_scale"]
+                    )
 
             if ans == 0.0:
                 for (p, state, param_names) in tuples:
-                    p.grad.zero_() # get rid of infinity()
+                    p.grad.zero_()  # get rid of infinity()
 
             return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor,
-        scalar_lr_scale: float,
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
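For orientation, a toy illustration of the ans == 0.0 branch above: when the clipping scale collapses to zero (for instance because the total gradient norm is not finite), each gradient is zeroed so the step becomes a no-op for those tensors. The parameter and tuple below are stand-ins.

import torch

p = torch.nn.Parameter(torch.randn(2, 3))
p.grad = torch.full_like(p, float("inf"))  # a blown-up gradient
tuples = [(p, {}, ["toy_param"])]          # (param, state, names) stand-in

ans = 0.0  # a clipping scale that collapsed to zero
if ans == 0.0:
    for (p, state, param_names) in tuples:
        p.grad.zero_()  # get rid of infinity

print(p.grad)  # all zeros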
@@ -531,11 +536,12 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.full(p.shape, scalar_lr_scale,
-                                            device=batch_grad.device)
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2)
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
             if batch_grad.ndim > 1:
                 # need to guard it with if-statement because sum() sums over
                 # all dims if dim == ().
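A minimal sketch, under assumed names and shapes, of how the per-parameter quantities built above (RMS-weighted squared gradients, reduced over all but the batch dimension) can be compared against tot_sumsq to spot a dominating parameter; this illustrates the idea only, not the method's actual reporting.

import torch

scalar_lr_scale = 0.1  # stand-in
tot_sumsq = torch.zeros(())
contributions = {}

# Two made-up batched parameters: a batch of scalars and a batch of matrices.
batched_params = {
    "toy_scalars": (torch.randn(3), None),  # param_rms unused for scalars
    "toy_weights": (torch.randn(3, 8, 8), torch.rand(3, 1, 1) + 0.1),
}

for name, (batch_grad, param_rms) in batched_params.items():
    if batch_grad.numel() == batch_grad.shape[0]:  # a batch of scalars
        batch_rms = torch.full(batch_grad.shape, scalar_lr_scale)
    else:
        batch_rms = param_rms
    batch_sumsq = (batch_grad * batch_rms) ** 2
    if batch_grad.ndim > 1:
        # sum over everything except the leading batch dimension
        batch_sumsq = batch_sumsq.sum(dim=list(range(1, batch_grad.ndim)))
    contributions[name] = batch_sumsq.sum()
    tot_sumsq += contributions[name]

dominant = max(contributions, key=lambda k: contributions[k])
frac = (contributions[dominant] / tot_sumsq).item()
print(f"{dominant} contributes {frac:.1%} of tot_sumsq")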
@@ -679,8 +685,7 @@ class ScaledAdam(BatchedOptimizer):
         # We have to look at the trained model for parameters at or around the
         # param_max_rms, because sometimes they can indicate a problem with the
         # topology or settings.
-        scale_step = torch.minimum(scale_step,
-                                   (param_max_rms - param_rms) / param_rms)
+        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
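A toy illustration of the clamping in the re-wrapped torch.minimum line: the proposed scale step is capped at the relative headroom left before a parameter's RMS would exceed param_max_rms. All values here are invented.

import torch

param_max_rms = torch.tensor(2.0)              # configured cap (stand-in)
param_rms = torch.tensor([0.5, 1.9, 2.0])      # current RMS of three parameters
scale_step = torch.tensor([0.30, 0.30, 0.30])  # proposed relative growth

# Cap each step at the relative headroom left before hitting param_max_rms.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(scale_step)  # roughly tensor([0.3000, 0.0526, 0.0000])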
@@ -897,7 +902,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 
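A worked sketch of the warmup expression that black re-wrapped here: the factor rises linearly from warmup_start to 1.0 over warmup_batches and stays at 1.0 afterwards. The 500-batch warmup and 0.5 start below are example values only; in Eden they correspond to the self.warmup_batches and self.warmup_start attributes used above.

def warmup_factor(batch: int, warmup_batches: float = 500.0, warmup_start: float = 0.5) -> float:
    # Linear ramp from warmup_start to 1.0, then flat.
    return (
        1.0
        if batch >= warmup_batches
        else warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)
    )

print(warmup_factor(0))    # 0.5
print(warmup_factor(250))  # 0.75
print(warmup_factor(500))  # 1.0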
@@ -946,7 +952,8 @@ class Eden2(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 