Update checkpoint.py to deal with int params

This commit is contained in:
Daniel Povey 2022-10-07 17:06:38 +08:00
parent ebf8aa129d
commit 28e5f46854
2 changed files with 12 additions and 5 deletions

View File

@@ -466,8 +466,10 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
-        state_dict_1[k] *= weight_1
-        state_dict_1[k] += (
-            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-        )
-        state_dict_1[k] *= scaling_factor
+        v = state_dict_1[k]
+        if torch.is_floating_point(v):
+            v *= weight_1
+            v += (
+                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            )
+            v *= scaling_factor

View File

@@ -137,6 +137,8 @@ class TensorDiagnostic(object):
             x = x[0]
         if not isinstance(x, Tensor):
             return
+        if x.numel() == 0:  # for empty tensor
+            return
         x = x.detach().clone()
         if x.ndim == 0:
             x = x.unsqueeze(0)
@@ -185,6 +187,9 @@ class TensorDiagnostic(object):
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
+        if self.stats is None:
+            print(f"Warning: the stats of {self.name} is None.")
+            return
         for dim, this_dim_stats in enumerate(self.stats):
             for stats_type, stats_list in this_dim_stats.items():
                 # stats_type could be "rms", "value", "abs", "eigs", "positive".