change current abbreviation

Guo Liyong 2022-12-21 16:35:38 +08:00
parent 9e4ce825ac
commit b011448a28


@@ -132,7 +132,7 @@ class ScaledAdam(BatchedOptimizer):
     Args:
       params: The parameters or param_groups to optimize (like other Optimizer subclasses)
-              Unlike common optimizers, which accepts model.parameters() or groups of parameters(),
+              Unlike common optimizers, which accept model.parameters() or groups of parameters(),
               this optimizer could accept model.named_parameters() or groups of named_parameters().
               See comments of function _get_names_of_parameters for its 4 possible cases.
       lr: The learning rate. We will typically use a learning rate schedule that starts
@@ -259,7 +259,7 @@ class ScaledAdam(BatchedOptimizer):
         # p is short for param.
         # np is short for named_param.
         # p_or_np is short for param_or_named_param.
-        # curt is short for current.
+        # cur is short for current.
         # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
         # groups is a List[group]
@@ -277,8 +277,8 @@ class ScaledAdam(BatchedOptimizer):
         if not isinstance(iterable_or_groups[0], dict):
             # case 1 or case 3,
             # the input is an iterable of parameter or named parameter.
-            param_iterable_curt_group = []
-            param_names_curt_group = []
+            param_iterable_cur_group = []
+            param_names_cur_group = []
             for p_or_np in iterable_or_groups:
                 if isinstance(p_or_np, tuple):
                     # case 3
@@ -290,17 +290,17 @@ class ScaledAdam(BatchedOptimizer):
                     # Assign a dummy name as a placeholder
                     name = "foo"
                     self.show_dominant_parameters = False
-                param_iterable_curt_group.append(param)
-                param_names_curt_group.append(name)
-            param_groups.append({"params": param_iterable_curt_group})
-            param_groups_names.append(param_names_curt_group)
+                param_iterable_cur_group.append(param)
+                param_names_cur_group.append(name)
+            param_groups.append({"params": param_iterable_cur_group})
+            param_groups_names.append(param_names_cur_group)
         else:
             # case 2 or case 4
             # the input is groups of parameter or named parameter.
-            for p_or_np_curt_group in iterable_or_groups:
-                param_iterable_curt_group = []
-                param_names_curt_group = []
-                p_or_np_iterable = p_or_np_curt_group["params"]
+            for p_or_np_cur_group in iterable_or_groups:
+                param_iterable_cur_group = []
+                param_names_cur_group = []
+                p_or_np_iterable = p_or_np_cur_group["params"]
                 for p_or_np in p_or_np_iterable:
                     if isinstance(p_or_np, tuple):
                         # case 4
@@ -312,8 +312,8 @@ class ScaledAdam(BatchedOptimizer):
                         # Assign a dummy name as a placeholder
                         name = "foo"
                         self.show_dominant_parameters = False
-                    param_iterable_curt_group.append(param)
-                    param_names_curt_group.append(name)
+                    param_iterable_cur_group.append(param)
+                    param_names_cur_group.append(name)
                 # The original `params` filed contains named_parameters.
                 # After following assignment,
@@ -321,9 +321,9 @@ class ScaledAdam(BatchedOptimizer):
                 # and other fileds, if exist, are still original values.
                 # So param_groups could be used to initialize
                 # an underlying torch.Optimizer later.
-                p_or_np_curt_group["params"] = param_iterable_curt_group
-                param_groups.append(p_or_np_curt_group)
-                param_groups_names.append(param_names_curt_group)
+                p_or_np_cur_group["params"] = param_iterable_cur_group
+                param_groups.append(p_or_np_cur_group)
+                param_groups_names.append(param_names_cur_group)
         return param_groups, param_groups_names

     def __setstate__(self, state):
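
The hunks above handle the four input forms mentioned in the updated docstring. As a rough usage sketch (not part of this commit; the import path and the 0.05 learning rate are placeholders, and all other constructor arguments are left at their defaults), the four call patterns look like this:

    import torch

    from optim import ScaledAdam  # assumed import path; adjust to your icefall checkout

    model = torch.nn.Linear(10, 10)

    # case 1: a plain iterable of parameters; names are unknown, so the code
    # above assigns the dummy name "foo" and disables show_dominant_parameters.
    opt1 = ScaledAdam(model.parameters(), lr=0.05)

    # case 2: groups of parameters, the usual torch.optim param_groups form.
    opt2 = ScaledAdam([{"params": model.parameters(), "lr": 0.05}])

    # case 3: an iterable of (name, parameter) tuples, i.e. model.named_parameters(),
    # so per-parameter diagnostics can be reported by name.
    opt3 = ScaledAdam(model.named_parameters(), lr=0.05)

    # case 4: groups whose "params" field holds named parameters; the diff above
    # rewrites each group's "params" back to plain parameters and keeps the
    # names in the parallel param_groups_names list.
    opt4 = ScaledAdam([{"params": model.named_parameters(), "lr": 0.05}])

Whichever form is passed in, the code above normalizes it into (param_groups, param_groups_names), which is why the renamed *_cur_group lists appear in every branch of the diff.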