mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
message formatting
This commit is contained in:
parent
89c3982a07
commit
9cf79cac3f
@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
|
|||||||
super(BatchedOptimizer, self).__init__(params, defaults)
|
super(BatchedOptimizer, self).__init__(params, defaults)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def batched_params(self, param_group, group_params_names=None):
|
def batched_params(self, param_group, group_params_names):
|
||||||
"""
|
"""
|
||||||
This function returns (technically, yields) a list of
|
This function returns (technically, yields) a list of
|
||||||
of tuples (p, state), where
|
of tuples (p, state), where
|
||||||
@ -85,7 +85,9 @@ class BatchedOptimizer(Optimizer):
|
|||||||
batches_names[key].append(named_p)
|
batches_names[key].append(named_p)
|
||||||
|
|
||||||
batches_names_keys = list(batches_names.keys())
|
batches_names_keys = list(batches_names.keys())
|
||||||
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
sorted_idx = sorted(
|
||||||
|
range(len(batches_names)), key=lambda i: batches_names_keys[i]
|
||||||
|
)
|
||||||
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
||||||
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
||||||
|
|
||||||
@ -174,7 +176,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
size_update_period=4,
|
size_update_period=4,
|
||||||
clipping_update_period=100,
|
clipping_update_period=100,
|
||||||
parameters_names=None,
|
parameters_names=None,
|
||||||
show_dominant_parameters=False,
|
show_dominant_parameters=True,
|
||||||
):
|
):
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
@ -381,42 +383,52 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
|
def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
|
||||||
# ori means calculated with state["param_rms"]
|
all_sumsq_orig = {}
|
||||||
# cur means calculated with "param_rms" of current param.
|
|
||||||
# bt is short batch
|
|
||||||
# all_sumsq_ori_rms
|
|
||||||
all_sumsq_ori = {}
|
|
||||||
all_sumsq_cur = {}
|
|
||||||
for (p, state, batch_param_names) in pairs:
|
for (p, state, batch_param_names) in pairs:
|
||||||
# p is a stacked batch parameters.
|
# p is a stacked batch parameters.
|
||||||
grad = p.grad
|
batch_grad = p.grad
|
||||||
if p.numel() == p.shape[0]: # a batch of scalars
|
if p.numel() == p.shape[0]: # a batch of scalars
|
||||||
batch_sumsq_ori = grad**2 # sum() to change shape [1] to []
|
batch_sumsq_orig = batch_grad**2
|
||||||
batch_sumsq_cur = batch_sumsq_ori # sum() to change shape [1] to []
|
|
||||||
# Dummpy values used by following `zip` statement.
|
# Dummpy values used by following `zip` statement.
|
||||||
batch_rms_ori = torch.zeros(p.shape[0])
|
batch_rms_orig = torch.ones(p.shape[0])
|
||||||
batch_rms_cur = batch_rms_ori
|
|
||||||
else:
|
else:
|
||||||
batch_rms_ori = state["param_rms"]
|
batch_rms_orig = state["param_rms"]
|
||||||
batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
|
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
|
||||||
|
dim=list(range(1, batch_grad.ndim))
|
||||||
|
)
|
||||||
|
|
||||||
batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
for name, sumsq_orig, rms, grad in zip(
|
||||||
batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
|
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
||||||
|
):
|
||||||
|
|
||||||
for name, sumsq_ori, sumsq_cur in zip(
|
proportion_orig = sumsq_orig / tot_sumsq
|
||||||
batch_param_names, batch_sumsq_ori, batch_sumsq_cur):
|
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||||
|
|
||||||
proportion_ori = sumsq_ori / tot_sumsq
|
assert torch.isclose(
|
||||||
proportion_cur = sumsq_cur / tot_sumsq
|
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||||
|
torch.tensor(1.0),
|
||||||
all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
|
)
|
||||||
all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
|
sorted_by_proportion = {
|
||||||
|
k: v
|
||||||
for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
|
for k, v in sorted(
|
||||||
sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)}
|
all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
|
||||||
|
)
|
||||||
|
}
|
||||||
dominant_param_name = next(iter(sorted_by_proportion))
|
dominant_param_name = next(iter(sorted_by_proportion))
|
||||||
dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
|
(
|
||||||
logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion} {dominant_sumsq} {tot_sumsq}")
|
dominant_proportion,
|
||||||
|
dominant_sumsq,
|
||||||
|
dominant_rms,
|
||||||
|
dominant_grad,
|
||||||
|
) = sorted_by_proportion[dominant_param_name]
|
||||||
|
logging.info(
|
||||||
|
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||||
|
f" with proportion {dominant_proportion:.2f},"
|
||||||
|
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||||
|
f"={dominant_sumsq:.3e},"
|
||||||
|
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||||
|
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _step_one_batch(
|
def _step_one_batch(
|
||||||
self, group: dict, p: Tensor, state: dict, clipping_scale: float
|
self, group: dict, p: Tensor, state: dict, clipping_scale: float
|
||||||
|
@ -368,13 +368,6 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--show-dominant-parameters",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Whether to show dominant parameters.",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -998,8 +991,7 @@ def run(rank, world_size, args):
|
|||||||
parameters_names = []
|
parameters_names = []
|
||||||
parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
|
parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
|
||||||
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
|
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
|
||||||
clipping_scale=2.0, parameters_names=parameters_names,
|
clipping_scale=2.0, parameters_names=parameters_names)
|
||||||
show_dominant_parameters=params.show_dominant_parameters)
|
|
||||||
|
|
||||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user