Remove dynamic weights in SimpleCombine

This commit is contained in:
Daniel Povey 2022-10-31 19:22:01 +08:00
parent b091ae5482
commit b7876baed6

View File

@ -825,9 +825,7 @@ class SimpleCombiner(torch.nn.Module):
min_weight: Tuple[float] = (0., 0.)):
super(SimpleCombiner, self).__init__()
assert dim2 >= dim1
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
self.bias = nn.Parameter(torch.zeros(()))
self.weight1 = nn.Parameter(torch.zeros(()))
self.min_weight = min_weight
def forward(self,
@ -843,25 +841,15 @@ class SimpleCombiner(torch.nn.Module):
dim1 = src1.shape[-1]
dim2 = src2.shape[-1]
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
logit = (weight1 + weight2 + self.bias)
if self.training and random.random() < 0.1:
logit = penalize_abs_values_gt(logit,
limit=25.0,
penalty=1.0e-04)
# `weight` will be the wight on src1.
weight = logit.sigmoid()
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
weight = weight.clamp(min=self.min_weight[0],
max=1.0-self.min_weight[1])
weight1 = self.weight1
if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
weight1 = weight1.clamp(min=self.min_weight[0],
max=1.0-self.min_weight[1])
src1 = src1 * weight
src2 = src2 * (1.0 - weight)
src1 = src1 * weight1
src2 = src2 * (1.0 - weight1)
src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1]