mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove dynamic weights in SimpleCombine
This commit is contained in:
parent
b091ae5482
commit
b7876baed6
@ -825,9 +825,7 @@ class SimpleCombiner(torch.nn.Module):
|
||||
min_weight: Tuple[float] = (0., 0.)):
|
||||
super(SimpleCombiner, self).__init__()
|
||||
assert dim2 >= dim1
|
||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||
self.bias = nn.Parameter(torch.zeros(()))
|
||||
self.weight1 = nn.Parameter(torch.zeros(()))
|
||||
self.min_weight = min_weight
|
||||
|
||||
def forward(self,
|
||||
@ -843,25 +841,15 @@ class SimpleCombiner(torch.nn.Module):
|
||||
dim1 = src1.shape[-1]
|
||||
dim2 = src2.shape[-1]
|
||||
|
||||
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
|
||||
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
|
||||
logit = (weight1 + weight2 + self.bias)
|
||||
|
||||
if self.training and random.random() < 0.1:
|
||||
logit = penalize_abs_values_gt(logit,
|
||||
limit=25.0,
|
||||
penalty=1.0e-04)
|
||||
|
||||
# `weight` will be the wight on src1.
|
||||
weight = logit.sigmoid()
|
||||
|
||||
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
|
||||
weight = weight.clamp(min=self.min_weight[0],
|
||||
max=1.0-self.min_weight[1])
|
||||
weight1 = self.weight1
|
||||
if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
|
||||
weight1 = weight1.clamp(min=self.min_weight[0],
|
||||
max=1.0-self.min_weight[1])
|
||||
|
||||
|
||||
src1 = src1 * weight
|
||||
src2 = src2 * (1.0 - weight)
|
||||
src1 = src1 * weight1
|
||||
src2 = src2 * (1.0 - weight1)
|
||||
|
||||
src1_dim = src1.shape[-1]
|
||||
src2_dim = src2.shape[-1]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user