mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove dynamic weights in SimpleCombiner
This commit is contained in:
parent
b091ae5482
commit
b7876baed6
@ -825,9 +825,7 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
min_weight: Tuple[float] = (0., 0.)):
|
min_weight: Tuple[float] = (0., 0.)):
|
||||||
super(SimpleCombiner, self).__init__()
|
super(SimpleCombiner, self).__init__()
|
||||||
assert dim2 >= dim1
|
assert dim2 >= dim1
|
||||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
self.weight1 = nn.Parameter(torch.zeros(()))
|
||||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
|
||||||
self.bias = nn.Parameter(torch.zeros(()))
|
|
||||||
self.min_weight = min_weight
|
self.min_weight = min_weight
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
@ -843,25 +841,15 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
dim1 = src1.shape[-1]
|
dim1 = src1.shape[-1]
|
||||||
dim2 = src2.shape[-1]
|
dim2 = src2.shape[-1]
|
||||||
|
|
||||||
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
|
|
||||||
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
|
|
||||||
logit = (weight1 + weight2 + self.bias)
|
|
||||||
|
|
||||||
if self.training and random.random() < 0.1:
|
weight1 = self.weight1
|
||||||
logit = penalize_abs_values_gt(logit,
|
if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
|
||||||
limit=25.0,
|
weight1 = weight1.clamp(min=self.min_weight[0],
|
||||||
penalty=1.0e-04)
|
max=1.0-self.min_weight[1])
|
||||||
|
|
||||||
# `weight` will be the weight on src1.
|
|
||||||
weight = logit.sigmoid()
|
|
||||||
|
|
||||||
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
|
|
||||||
weight = weight.clamp(min=self.min_weight[0],
|
|
||||||
max=1.0-self.min_weight[1])
|
|
||||||
|
|
||||||
|
|
||||||
src1 = src1 * weight
|
src1 = src1 * weight1
|
||||||
src2 = src2 * (1.0 - weight)
|
src2 = src2 * (1.0 - weight1)
|
||||||
|
|
||||||
src1_dim = src1.shape[-1]
|
src1_dim = src1.shape[-1]
|
||||||
src2_dim = src2.shape[-1]
|
src2_dim = src2.shape[-1]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user