Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 00:24:19 +00:00

commit 6dd7fc46d2 (parent cab8c73c63)

    black formatted
@@ -300,8 +300,8 @@ class ScaledAdam(BatchedOptimizer):
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
                 assert "named_params" in cur_group
-                name_list = [ x[0] for x in cur_group["named_params"] ]
-                p_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
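For context on what the hunk above touches: ScaledAdam accepts parameter groups keyed by "named_params" (lists of (name, tensor) pairs) and rewrites each group in place to the standard "params" key. A minimal runnable sketch, with made-up group contents, string stand-ins for tensors, and a hypothetical param_groups_names list to show where name_list would go (the diff does not show its use):

# Toy illustration, not part of the diff; the "lr" value is invented.
param_groups, param_groups_names = [], []
iterable_or_groups = [
    {"named_params": [("weight", "W-tensor"), ("bias", "b-tensor")], "lr": 0.04}
]
for cur_group in iterable_or_groups:
    assert "named_params" in cur_group
    name_list = [x[0] for x in cur_group["named_params"]]
    p_list = [x[1] for x in cur_group["named_params"]]
    del cur_group["named_params"]
    cur_group["params"] = p_list
    param_groups.append(cur_group)
    param_groups_names.append(name_list)  # assumed use of name_list

print(param_groups)        # [{'lr': 0.04, 'params': ['W-tensor', 'b-tensor']}]
print(param_groups_names)  # [['weight', 'bias']]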
@@ -437,7 +437,9 @@ class ScaledAdam(BatchedOptimizer):
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum() * (scalar_lr_scale ** 2)  # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
 
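The reflowed expression belongs to the accumulation of a model-wide squared gradient norm: scalar batches contribute their raw squared gradients scaled by scalar_lr_scale**2, while other parameters are first weighted by their running RMS. A minimal sketch with made-up shapes; param_rms stands in for state["param_rms"]:

import torch

scalar_lr_scale = 0.1
tot_sumsq = torch.zeros(())

# A "batch of scalars": numel() == shape[0], scaled by scalar_lr_scale**2.
grad_scalar = torch.tensor([0.5, -0.2])
tot_sumsq += (grad_scalar**2).sum() * (scalar_lr_scale**2)

# A batch of weight matrices: gradients weighted by the parameter RMS.
grad_tensor = torch.randn(2, 4, 3)
param_rms = torch.rand(2, 1, 1)  # broadcastable per-parameter RMS (assumed shape)
tot_sumsq += ((grad_tensor * param_rms) ** 2).sum()

tot_norm = tot_sumsq.sqrt()  # the norm later used for clipping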
@@ -448,7 +450,9 @@ class ScaledAdam(BatchedOptimizer):
         )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        irregular_estimate_steps = [ i for i in [ 10, 20, 40 ] if i < clipping_update_period ]
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
         if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
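irregular_estimate_steps lets the optimizer report clipping statistics at a few early steps, before a full clipping_update_period of norms has accumulated. A quick sketch of which steps trigger the printout, assuming a period of 100 (the actual default is not shown in this diff):

clipping_update_period = 100  # assumed value for illustration
irregular_estimate_steps = [i for i in [10, 20, 40] if i < clipping_update_period]
printed = [
    step
    for step in range(1, 301)
    if step % clipping_update_period == 0 or step in irregular_estimate_steps
]
print(printed)  # [10, 20, 40, 100, 200, 300]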
@@ -459,9 +463,7 @@ class ScaledAdam(BatchedOptimizer):
             num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    num_norms - 1, (num_norms // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
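The collapsed min(...) call picks five order statistics from the sorted norms; quartiles[2] just below the hunk is then the median. A quick check with toy numbers:

# With 100 sorted norms, n = 0..4 gives 0, 25, 50, 75 and
# (num_norms // 4) * 4 == 100, clamped to 99 by min() -- i.e. the
# minimum, quartiles, median and maximum.
num_norms = 100
indices = [min(num_norms - 1, (num_norms // 4) * n) for n in range(0, 5)]
print(indices)  # [0, 25, 50, 75, 99]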
@@ -499,18 +501,21 @@ class ScaledAdam(BatchedOptimizer):
             )
             if self.show_dominant_parameters:
                 assert p.shape[0] == len(param_names)
-                self._show_gradient_dominating_parameter(tuples, tot_sumsq,
-                                                          group["scalar_lr_scale"])
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
 
         if ans == 0.0:
             for (p, state, param_names) in tuples:
                 p.grad.zero_()  # get rid of infinity()
 
         return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor,
-        scalar_lr_scale: float,
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
@@ -531,11 +536,12 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.full(p.shape, scalar_lr_scale,
-                                            device=batch_grad.device)
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2)
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
             if batch_grad.ndim > 1:
                 # need to guard it with if-statement because sum() sums over
                 # all dims if dim == ().
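The parenthesis change above is behavior-neutral; the interesting part is the reduction guarded by batch_grad.ndim > 1 (per the comment, sum(dim=()) would reduce over all dims, which must not happen for 1-D scalar batches). A sketch of the non-batch-dim sum, with made-up shapes:

import torch

batch_grad = torch.randn(2, 4, 3)      # a batch of two 4x3 parameters
batch_rms_orig = torch.rand(2, 1, 1)   # broadcastable RMS (assumed shape)
batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
# Sum over every dim except the batch dim:
batch_sumsq = batch_sumsq_orig.sum(dim=list(range(1, batch_grad.ndim)))
print(batch_sumsq.shape)  # torch.Size([2]): one value per parameter in the batch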
@@ -679,8 +685,7 @@ class ScaledAdam(BatchedOptimizer):
         # We have to look at the trained model for parameters at or around the
         # param_max_rms, because sometimes they can indicate a problem with the
         # topology or settings.
-        scale_step = torch.minimum(scale_step,
-                                   (param_max_rms - param_rms) / param_rms)
+        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
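The joined torch.minimum line caps the relative scale step so a parameter's RMS cannot be pushed past param_max_rms. A sketch with toy values, not taken from the diff:

import torch

param_max_rms = torch.tensor(2.0)
param_rms = torch.tensor([0.5, 1.9])     # one parameter far from, one near the limit
scale_step = torch.tensor([0.10, 0.10])  # proposed relative RMS growth
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(scale_step)  # tensor([0.1000, 0.0526]): the near-limit parameter is capped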
@@ -897,7 +902,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 
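The split else branch is Eden's warmup ramp: linear from warmup_start to 1.0 over the first warmup_batches batches, then constant 1.0. A self-contained sketch with hypothetical hyperparameter values (the Eden2 hunk below applies the identical formula):

warmup_start, warmup_batches = 0.5, 500.0  # assumed values for illustration

def warmup_factor(batch: float) -> float:
    return (
        1.0
        if batch >= warmup_batches
        else warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)
    )

print(warmup_factor(0), warmup_factor(250), warmup_factor(500))  # 0.5 0.75 1.0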
@@ -946,7 +952,8 @@ class Eden2(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 
|