Repository: https://github.com/k2-fsa/icefall.git
Commit: 4fee3e7f1e (parent: 9cf79cac3f)

    improve comment
@@ -64,13 +64,15 @@ class BatchedOptimizer(Optimizer):
           you can do:
           <code>
             with self.batched_params(group["params"]) as batches:
-               for p, state in batches:
+               for p, state, p_names in batches:
                   ...
           </code>

         Args:
           group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+          group_params_names: name for each parameter in group,
+                which is List[str].
         """
         batches = defaultdict(
             list
@@ -79,6 +81,7 @@ class BatchedOptimizer(Optimizer):
             list
         )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

+        assert len(param_group) == len(group_params_names)
         for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
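The assert added above guards the alignment between parameters and their names before they are grouped by (dtype, shape). As a standalone illustration (the tensors and names below are invented, not taken from icefall), the grouping works like this:

# Standalone sketch (assumed example, not icefall code): group parameters and
# their names by (dtype_as_str, *shape), mirroring the keying in batched_params.
from collections import defaultdict

import torch

params = [torch.zeros(3, 4), torch.zeros(3, 4), torch.zeros(5)]
names = ["enc.w1", "enc.w2", "enc.bias"]

assert len(params) == len(names)  # same check as the assert added above

batches = defaultdict(list)        # (dtype_as_str, *shape) -> list of Tensor
batches_names = defaultdict(list)  # (dtype_as_str, *shape) -> list of str
for p, named_p in zip(params, names):
    key = (str(p.dtype), *p.shape)
    batches[key].append(p)
    batches_names[key].append(named_p)

for key in batches:
    print(key, batches_names[key])
# ('torch.float32', 3, 4) ['enc.w1', 'enc.w2']
# ('torch.float32', 5) ['enc.bias']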
@@ -94,9 +97,9 @@ class BatchedOptimizer(Optimizer):
         stacked_params_dict = dict()

         # turn batches into a list, in deterministic order.
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []

         for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
@@ -110,11 +113,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state, batch_names))
+            tuples.append((p_stacked, state, batch_names))

-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])

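The rename from pairs to tuples does not change the write-back mechanics of batched_params. A much-simplified standalone sketch of the stack / yield / copy-back pattern (no optimizer state, no grouping by dtype and shape, and the helper name `stacked` is hypothetical):

# Simplified sketch of the stack / yield / copy-back pattern; an illustration,
# not the icefall implementation.
from contextlib import contextmanager
from typing import List

import torch


@contextmanager
def stacked(params: List[torch.Tensor]):
    p_stacked = torch.stack(params)    # 1st dim is the batch dim
    yield p_stacked                    # caller updates p_stacked in place
    with torch.no_grad():
        for i, p in enumerate(params):
            p.copy_(p_stacked[i])      # write results back to the real tensors


ps = [torch.ones(2, 2), 2 * torch.ones(2, 2)]
with stacked(ps) as sp:
    sp *= 10.0                         # stands in for an optimizer update
print(ps[0])                           # all 10s: the update was copied back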
@@ -179,6 +182,11 @@ class ScaledAdam(BatchedOptimizer):
         show_dominant_parameters=True,
     ):

+        assert parameters_names is not None, (
+            "Please prepare parameters_names,"
+            "which is a List[List[str]]. Each List[str] is for a group"
+            "and each str is for a parameter"
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -193,6 +201,7 @@ class ScaledAdam(BatchedOptimizer):
         )

         super(ScaledAdam, self).__init__(params, defaults)
+        assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
         self.show_dominant_parameters = show_dominant_parameters

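Taken together, the two new asserts pin down what callers must pass: parameters_names is a List[List[str]] with exactly one inner list per parameter group, in the same order as the groups given to the optimizer. A small sketch of building it (the toy model and the two-group split are invented for illustration):

# Sketch (assumed example): one name list per parameter group, in the same
# order as the groups handed to the optimizer.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Single parameter group, as in the run() hunk further below:
parameters_names = [[name for name, _ in model.named_parameters()]]

# Hypothetical two-group setup: the outer list then needs two entries so that
# len(self.param_groups) == len(parameters_names) holds.
group0 = [(n, p) for n, p in model.named_parameters() if n.startswith("0.")]
group1 = [(n, p) for n, p in model.named_parameters() if n.startswith("1.")]
parameters_names_2groups = [
    [n for n, _ in group0],
    [n for n, _ in group1],
]
print(parameters_names[0])                          # ['0.weight', '0.bias', '1.weight', '1.bias']
print([len(g) for g in parameters_names_2groups])   # [2, 2]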
@@ -213,7 +222,6 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()

         batch = True
-        assert len(self.param_groups) == len(self.parameters_names)

         for group, group_params_names in zip(self.param_groups, self.parameters_names):

@@ -292,7 +300,7 @@ class ScaledAdam(BatchedOptimizer):
                 state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -300,12 +308,16 @@ class ScaledAdam(BatchedOptimizer):

         Args:
           group: the parameter group, an item in self.param_groups
-          pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-            (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+          tuples: a list of tuples of (param, state, param_names)
+            where param is a batched set of parameters,
+            with a .grad (1st dim is batch dim)
+            and state is the state-dict where optimization parameters are kept.
+            param_names is a List[str], where each str is the name of a parameter
+            in the batched set of parameters "param".
         """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state, _) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping. return early on step == 0 because the other
@@ -314,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
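For intuition only, a heavily simplified standalone sketch of accumulating a gradient sum-of-squares over such (param, state, param_names) tuples and turning it into a clipping factor. The real _get_clipping_scale keeps a history of gradient norms over clipping_update_period steps and weights each gradient by the parameter's rms, none of which is reproduced here; the threshold below is a made-up stand-in:

# Simplified sketch (not the icefall implementation): accumulate the squared
# gradient norm over batched parameters and derive a clipping factor <= 1.0.
from typing import Dict, List, Tuple

import torch

def simple_clipping_scale(
    tuples: List[Tuple[torch.Tensor, Dict, List[str]]],
    max_grad_norm: float = 1.0,   # hypothetical threshold, not an icefall option
) -> float:
    tot_sumsq = torch.tensor(0.0, device=tuples[0][0].device)
    for p, state, param_names in tuples:
        grad = p.grad
        tot_sumsq += (grad ** 2).sum()
    tot_norm = tot_sumsq.sqrt()
    return min(1.0, float(max_grad_norm / (tot_norm + 1.0e-20)))

p = torch.zeros(3, 2, 2)          # a batch of three 2x2 parameters
p.grad = torch.full_like(p, 0.5)
print(simple_clipping_scale([(p, {}, ["a", "b", "c"])]))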
@@ -379,12 +391,27 @@ class ScaledAdam(BatchedOptimizer):
                 )
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
+                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
         return ans

-    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information about the parameter that dominates tot_sumsq.
+
+        Args:
+          tuples: a list of tuples of (param, state, param_names)
+            where param is a batched set of parameters,
+            with a .grad (1st dim is batch dim)
+            and state is the state-dict where optimization parameters are kept.
+            param_names is a List[str], where each str is the name of a parameter
+            in the batched set of parameters "param".
+          tot_sumsq: sumsq of all parameters. Though it could be calculated
+            from tuples, we still pass it in to save some time.
+        """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in pairs:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
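The idea behind _show_gradient_dominating_parameter can be sketched roughly: compute each named parameter's share of tot_sumsq from the batched gradients and report the largest one. The rms weighting that ScaledAdam applies is omitted here, and the tensors and names are invented:

# Simplified sketch (not the icefall implementation): find which named
# parameter contributes the largest share of the total gradient sumsq.
import torch

def dominant_parameter(tuples, tot_sumsq):
    all_sumsq = {}
    for p, state, batch_param_names in tuples:
        batch_grad = p.grad                     # shape (batch, *param_shape)
        # per-parameter sumsq: reduce over every dim except the batch dim
        per_param = (batch_grad ** 2).reshape(batch_grad.shape[0], -1).sum(dim=1)
        for name, s in zip(batch_param_names, per_param):
            all_sumsq[name] = float(s / tot_sumsq)
    name, frac = max(all_sumsq.items(), key=lambda kv: kv[1])
    print(f"Dominant parameter: {name}, fraction of tot_sumsq = {frac:.3f}")

p = torch.zeros(2, 3, 3)                        # a batch of two 3x3 parameters
p.grad = torch.stack([torch.ones(3, 3), 0.1 * torch.ones(3, 3)])
tot_sumsq = (p.grad ** 2).sum()
dominant_parameter([(p, {}, ["enc.w", "dec.w"])], tot_sumsq)
# -> Dominant parameter: enc.w, fraction of tot_sumsq = 0.990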
@@ -989,9 +989,15 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     parameters_names = []
-    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-                           clipping_scale=2.0, parameters_names=parameters_names)
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

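Putting the optimizer change and the training-script change together, a minimal usage sketch. It assumes the recipe's optim.py (which defines ScaledAdam) is importable as `optim`; the toy model and learning rate are illustrative only:

# Minimal usage sketch with a toy model; assumes the recipe's optim.py is on the path.
import torch
import torch.nn as nn

from optim import ScaledAdam   # the class modified in this commit

model = nn.Linear(10, 5)       # toy stand-in for the real model

parameters_names = []
parameters_names.append(
    [name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
    model.parameters(),
    lr=0.045,                  # illustrative value, not a recipe default
    clipping_scale=2.0,
    parameters_names=parameters_names,
)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()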