mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update checkpoint.py to deal with int params
This commit is contained in:
parent
ebf8aa129d
commit
28e5f46854
@@ -466,8 +466,10 @@ def average_state_dict(
|
|||||||
|
|
||||||
uniqued_names = list(uniqued.values())
|
uniqued_names = list(uniqued.values())
|
||||||
for k in uniqued_names:
|
for k in uniqued_names:
|
||||||
state_dict_1[k] *= weight_1
|
v = state_dict_1[k]
|
||||||
state_dict_1[k] += (
|
if torch.is_floating_point(v):
|
||||||
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
|
v *= weight_1
|
||||||
)
|
v += (
|
||||||
state_dict_1[k] *= scaling_factor
|
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
|
||||||
|
)
|
||||||
|
v *= scaling_factor
|
||||||
|
|||||||
@@ -137,6 +137,8 @@ class TensorDiagnostic(object):
|
|||||||
x = x[0]
|
x = x[0]
|
||||||
if not isinstance(x, Tensor):
|
if not isinstance(x, Tensor):
|
||||||
return
|
return
|
||||||
|
if x.numel() == 0: # for empty tensor
|
||||||
|
return
|
||||||
x = x.detach().clone()
|
x = x.detach().clone()
|
||||||
if x.ndim == 0:
|
if x.ndim == 0:
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
@@ -185,6 +187,9 @@ class TensorDiagnostic(object):
|
|||||||
|
|
||||||
def print_diagnostics(self):
|
def print_diagnostics(self):
|
||||||
"""Print diagnostics for each dimension of the tensor."""
|
"""Print diagnostics for each dimension of the tensor."""
|
||||||
|
if self.stats is None:
|
||||||
|
print(f"Warning: the stats of {self.name} is None.")
|
||||||
|
return
|
||||||
for dim, this_dim_stats in enumerate(self.stats):
|
for dim, this_dim_stats in enumerate(self.stats):
|
||||||
for stats_type, stats_list in this_dim_stats.items():
|
for stats_type, stats_list in this_dim_stats.items():
|
||||||
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user