mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change initialization value of weight in SimpleCombine from 0.0 to 0.1; ignore infinities in MetricsTracker
.
This commit is contained in:
parent
a27670d097
commit
97a1dd40cf
@ -853,7 +853,8 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
min_weight: Tuple[float] = (0., 0.)):
|
min_weight: Tuple[float] = (0., 0.)):
|
||||||
super(SimpleCombiner, self).__init__()
|
super(SimpleCombiner, self).__init__()
|
||||||
assert dim2 >= dim1
|
assert dim2 >= dim1
|
||||||
self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0])
|
initial_weight1 = 0.1
|
||||||
|
self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
|
||||||
self.min_weight = min_weight
|
self.min_weight = min_weight
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
|
|||||||
@ -530,9 +530,10 @@ class MetricsTracker(collections.defaultdict):
|
|||||||
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
|
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
|
||||||
ans = MetricsTracker()
|
ans = MetricsTracker()
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
ans[k] = v
|
ans[k] = v if v - v == 0.0 else 0.0 # discard infinities.
|
||||||
for k, v in other.items():
|
for k, v in other.items():
|
||||||
ans[k] = ans[k] + v
|
if v - v == 0: # discard infinities.
|
||||||
|
ans[k] = ans[k] + v
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def __mul__(self, alpha: float) -> "MetricsTracker":
|
def __mul__(self, alpha: float) -> "MetricsTracker":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user