Change initialization value of weight in SimpleCombine from 0.0 to 0.1; ignore infinities in MetricsTracker

.
This commit is contained in:
Daniel Povey 2022-11-03 13:46:14 +08:00
parent a27670d097
commit 97a1dd40cf
2 changed files with 5 additions and 3 deletions

View File

@ -853,7 +853,8 @@ class SimpleCombiner(torch.nn.Module):
min_weight: Tuple[float] = (0., 0.)): min_weight: Tuple[float] = (0., 0.)):
super(SimpleCombiner, self).__init__() super(SimpleCombiner, self).__init__()
assert dim2 >= dim1 assert dim2 >= dim1
self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0]) initial_weight1 = 0.1
self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
self.min_weight = min_weight self.min_weight = min_weight
def forward(self, def forward(self,

View File

@ -530,8 +530,9 @@ class MetricsTracker(collections.defaultdict):
def __add__(self, other: "MetricsTracker") -> "MetricsTracker": def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker() ans = MetricsTracker()
for k, v in self.items(): for k, v in self.items():
ans[k] = v ans[k] = v if v - v == 0.0 else 0.0 # discard infinities.
for k, v in other.items(): for k, v in other.items():
if v - v == 0: # discard infinities.
ans[k] = ans[k] + v ans[k] = ans[k] + v
return ans return ans