From 97a1dd40cf943508fcabc15b323b63752bc3803c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Nov 2022 13:46:14 +0800 Subject: [PATCH] Change initialization value of weight in SimpleCombine from 0.0 to 0.1; ignore infinities in MetricsTracker . --- .../ASR/pruned_transducer_stateless7/zipformer.py | 3 ++- icefall/utils.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 1c181b1ac..2bcd7202d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -853,7 +853,8 @@ class SimpleCombiner(torch.nn.Module): min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1 - self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0]) + initial_weight1 = 0.1 + self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1)) self.min_weight = min_weight def forward(self, diff --git a/icefall/utils.py b/icefall/utils.py index ad079222e..09523fdf2 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -530,9 +530,10 @@ class MetricsTracker(collections.defaultdict): def __add__(self, other: "MetricsTracker") -> "MetricsTracker": ans = MetricsTracker() for k, v in self.items(): - ans[k] = v + ans[k] = v if v - v == 0.0 else 0.0 # discard infinities. for k, v in other.items(): - ans[k] = ans[k] + v + if v - v == 0: # discard infinities. + ans[k] = ans[k] + v return ans def __mul__(self, alpha: float) -> "MetricsTracker":