Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00

Merge pull request #705 from glynpu/improve_diagnostic
[ready] show dominant parameters

This change is contained in commit 1d5c03f85a.
@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)

     @contextlib.contextmanager
-    def batched_params(self, param_group):
+    def batched_params(self, param_group, group_params_names):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -64,31 +64,44 @@ class BatchedOptimizer(Optimizer):
         you can do:
         <code>
           with self.batched_params(group["params"]) as batches:
-             for p, state in batches:
+             for p, state, p_names in batches:
                  ...
         </code>

         Args:
           group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+          group_params_names: name for each parameter in group,
+                which is List[str].
         """
         batches = defaultdict(
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(
+            list
+        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
+
-        for p in param_group:
+        assert len(param_group) == len(group_params_names)
+        for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
+            batches_names[key].append(named_p)
+
+        batches_names_keys = list(batches_names.keys())
+        sorted_idx = sorted(
+            range(len(batches_names)), key=lambda i: batches_names_keys[i]
+        )
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

         stacked_params_dict = dict()

         # turn batches into a list, in deterministic order.
-        batches = [batches[key] for key in sorted(batches.keys())]
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []

-        for batch in batches:
+        for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
             # we arbitrarily store the state in the
             # state corresponding to the 1st parameter in the
@@ -100,11 +113,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state))
+            tuples.append((p_stacked, state, batch_names))

-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])

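For readers skimming the diff, here is a minimal standalone sketch of the batching bookkeeping that `batched_params` performs and that this change extends with names; the tensors and parameter names below are invented for illustration and are not part of icefall. Parameters sharing a dtype and shape are collected under one key so they can be stacked along a new leading batch dimension, and the new `group_params_names` argument keeps a parallel list of names per key so each stacked batch can be traced back to the real parameters it contains.

```python
# Toy sketch (hypothetical tensors and names) of grouping parameters by
# (dtype_as_str, *shape) together with their names, mirroring batched_params.
from collections import defaultdict

import torch

params = [torch.zeros(3, 4), torch.zeros(3, 4), torch.zeros(5)]
names = ["encoder.w1", "encoder.w2", "decoder.bias"]

batches = defaultdict(list)        # (dtype_as_str, *shape) -> list of tensors
batches_names = defaultdict(list)  # (dtype_as_str, *shape) -> list of names

for p, name in zip(params, names):
    key = (str(p.dtype), *p.shape)
    batches[key].append(p)
    batches_names[key].append(name)

for key in sorted(batches.keys()):
    stacked = torch.stack(batches[key])  # dim 0 is the batch of same-shaped params
    print(key, batches_names[key], tuple(stacked.shape))
# ('torch.float32', 3, 4) ['encoder.w1', 'encoder.w2'] (2, 3, 4)
# ('torch.float32', 5) ['decoder.bias'] (1, 5)
```

The real `batched_params` additionally stacks the gradients, keeps one optimizer state dict per batch, and copies the updated stacked values back into the original parameters when the context manager exits.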
@@ -165,8 +178,15 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
+        parameters_names=None,
+        show_dominant_parameters=True,
     ):

+        assert parameters_names is not None, (
+            "Please prepare parameters_names,"
+            "which is a List[List[str]]. Each List[str] is for a group"
+            "and each str is for a parameter"
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -181,6 +201,9 @@ class ScaledAdam(BatchedOptimizer):
         )

         super(ScaledAdam, self).__init__(params, defaults)
+        assert len(self.param_groups) == len(parameters_names)
+        self.parameters_names = parameters_names
+        self.show_dominant_parameters = show_dominant_parameters

     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
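The constructor changes above mean every caller must now supply `parameters_names`, a `List[List[str]]` with one inner list per parameter group. A minimal construction sketch, assuming a single parameter group; the model and learning rate are placeholders, and `ScaledAdam` is assumed to be importable from the recipe's `optim.py`:

```python
# Sketch: constructing ScaledAdam with the new arguments (placeholder model/lr).
import torch

from optim import ScaledAdam  # the recipe's optim.py, assumed to be on the path

model = torch.nn.Linear(10, 10)  # stand-in model for illustration

# One List[str] per parameter group; here there is a single group.
parameters_names = [[name for name, _ in model.named_parameters()]]

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.05,  # placeholder value
    clipping_scale=2.0,
    parameters_names=parameters_names,
    show_dominant_parameters=True,  # default; report the dominant parameter
)
```

Omitting `parameters_names` now trips the assertion added above instead of silently running without names.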
@@ -199,9 +222,10 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()

         batch = True
-        for group in self.param_groups:
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):

-            with self.batched_params(group["params"]) as batches:
+            with self.batched_params(group["params"], group_params_names) as batches:

                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
@@ -214,7 +238,7 @@ class ScaledAdam(BatchedOptimizer):
                 else:
                     clipping_scale = self._get_clipping_scale(group, batches)

-                for p, state in batches:
+                for p, state, _ in batches:
                     # Perform optimization step.
                     # grad is not going to be None, we handled that when creating the batches.
                     grad = p.grad
@@ -276,7 +300,7 @@ class ScaledAdam(BatchedOptimizer):
             state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -284,12 +308,16 @@ class ScaledAdam(BatchedOptimizer):

         Args:
            group: the parameter group, an item in self.param_groups
-           pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-             (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+           tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str] while each str is name for a parameter
+                in batched set of parameters "param".
         """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping. return early on step == 0 because the other
@@ -298,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -361,8 +389,74 @@ class ScaledAdam(BatchedOptimizer):
                 logging.warn(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
+                if self.show_dominant_parameters:
+                    assert p.shape[0] == len(param_names)
+                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
             return ans

+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information of the parameter which dominates tot_sumsq.
+
+        Args:
+            tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str] while each str is name for a parameter
+                in batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
+                from tuples, we still pass it to save some time.
+        """
+        all_sumsq_orig = {}
+        for (p, state, batch_param_names) in tuples:
+            # p is a stacked batch of parameters.
+            batch_grad = p.grad
+            if p.numel() == p.shape[0]:  # a batch of scalars
+                batch_sumsq_orig = batch_grad**2
+                # Dummy values used by the following `zip` statement.
+                batch_rms_orig = torch.ones(p.shape[0])
+            else:
+                batch_rms_orig = state["param_rms"]
+                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+                    dim=list(range(1, batch_grad.ndim))
+                )
+
+            for name, sumsq_orig, rms, grad in zip(
+                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+            ):
+
+                proportion_orig = sumsq_orig / tot_sumsq
+                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+        assert torch.isclose(
+            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+            torch.tensor(1.0),
+        )
+        sorted_by_proportion = {
+            k: v
+            for k, v in sorted(
+                all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+            )
+        }
+        dominant_param_name = next(iter(sorted_by_proportion))
+        (
+            dominant_proportion,
+            dominant_sumsq,
+            dominant_rms,
+            dominant_grad,
+        ) = sorted_by_proportion[dominant_param_name]
+        logging.info(
+            f"Parameter dominating tot_sumsq {dominant_param_name}"
+            f" with proportion {dominant_proportion:.2f},"
+            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+            f"={dominant_sumsq:.3e},"
+            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )
+
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
     ):
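To make the arithmetic of `_show_gradient_dominating_parameter` concrete, here is a detached toy computation with invented numbers (not taken from any real run): each non-scalar parameter contributes `((grad * param_rms) ** 2).sum()`, the contributions are normalized by `tot_sumsq` so the proportions sum to 1.0, and the parameter with the largest proportion is the one reported.

```python
# Toy sketch (invented gradients and RMS values) of the dominance computation.
import torch

grads = {
    "encoder.w": torch.tensor([[0.30, -0.20], [0.10, 0.05]]),
    "decoder.w": torch.tensor([[0.02, 0.01], [0.00, 0.03]]),
}
param_rms = {"encoder.w": torch.tensor(0.5), "decoder.w": torch.tensor(0.1)}

sumsq = {name: ((g * param_rms[name]) ** 2).sum() for name, g in grads.items()}
tot_sumsq = sum(sumsq.values())
proportions = {name: (s / tot_sumsq).item() for name, s in sumsq.items()}

dominant = max(proportions, key=proportions.get)
print(dominant, round(proportions[dominant], 2))  # "encoder.w" dominates here
```

In the optimizer itself this report is emitted inside the branch that warns about scaling gradients down, which is exactly when knowing which parameter dominates the gradient norm is useful.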
@@ -988,7 +988,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

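The training-script change above relies on `model.parameters()` and `model.named_parameters()` yielding parameters in the same order, so the single name list lines up one-to-one with the parameters handed to the optimizer. A quick sanity-check sketch with a stand-in model (not from icefall):

```python
# Sketch: names from named_parameters() align with the parameters() order.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

names = [name for name, _ in model.named_parameters()]
params = list(model.parameters())
assert len(names) == len(params)
for (name, named_p), p in zip(model.named_parameters(), model.parameters()):
    assert named_p is p  # same objects, same order

print(names)  # ['0.weight', '0.bias', '1.weight', '1.bias']
```

If a recipe ever splits parameters into several groups, `parameters_names` must grow to one name list per group, matching the `len(self.param_groups) == len(parameters_names)` assertion added in the constructor.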