diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 1c181b1ac..2bcd7202d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -853,7 +853,8 @@ class SimpleCombiner(torch.nn.Module): min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1 - self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0]) + initial_weight1 = 0.1 + self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1)) self.min_weight = min_weight def forward(self, diff --git a/icefall/utils.py b/icefall/utils.py index ad079222e..09523fdf2 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -530,9 +530,10 @@ class MetricsTracker(collections.defaultdict): def __add__(self, other: "MetricsTracker") -> "MetricsTracker": ans = MetricsTracker() for k, v in self.items(): - ans[k] = v + ans[k] = v if v - v == 0.0 else 0.0 # discard infinities. for k, v in other.items(): - ans[k] = ans[k] + v + if v - v == 0: # discard infinities. + ans[k] = ans[k] + v return ans def __mul__(self, alpha: float) -> "MetricsTracker":