diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 4e02dd382..5069b78e8 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -466,8 +466,10 @@ def average_state_dict( uniqued_names = list(uniqued.values()) for k in uniqued_names: - state_dict_1[k] *= weight_1 - state_dict_1[k] += ( - state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 - ) - state_dict_1[k] *= scaling_factor + v = state_dict_1[k] + if torch.is_floating_point(v): + v *= weight_1 + v += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) + v *= scaling_factor diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 071d27112..9d2ea46c9 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -137,6 +137,8 @@ class TensorDiagnostic(object): x = x[0] if not isinstance(x, Tensor): return + if x.numel() == 0: # for empty tensor + return x = x.detach().clone() if x.ndim == 0: x = x.unsqueeze(0) @@ -185,6 +187,9 @@ class TensorDiagnostic(object): def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" + if self.stats is None: + print(f"Warning: the stats of {self.name} is None.") + return for dim, this_dim_stats in enumerate(self.stats): for stats_type, stats_list in this_dim_stats.items(): # stats_type could be "rms", "value", "abs", "eigs", "positive".