Remove dynamic weights in SimpleCombine

This commit is contained in:
Daniel Povey 2022-10-31 19:22:01 +08:00
parent b091ae5482
commit b7876baed6

View File

@ -825,9 +825,7 @@ class SimpleCombiner(torch.nn.Module):
min_weight: Tuple[float] = (0., 0.)): min_weight: Tuple[float] = (0., 0.)):
super(SimpleCombiner, self).__init__() super(SimpleCombiner, self).__init__()
assert dim2 >= dim1 assert dim2 >= dim1
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01) self.weight1 = nn.Parameter(torch.zeros(()))
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
self.bias = nn.Parameter(torch.zeros(()))
self.min_weight = min_weight self.min_weight = min_weight
def forward(self, def forward(self,
@ -843,25 +841,15 @@ class SimpleCombiner(torch.nn.Module):
dim1 = src1.shape[-1] dim1 = src1.shape[-1]
dim2 = src2.shape[-1] dim2 = src2.shape[-1]
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
logit = (weight1 + weight2 + self.bias)
if self.training and random.random() < 0.1: weight1 = self.weight1
logit = penalize_abs_values_gt(logit, if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
limit=25.0, weight1 = weight1.clamp(min=self.min_weight[0],
penalty=1.0e-04) max=1.0-self.min_weight[1])
# `weight` will be the wight on src1.
weight = logit.sigmoid()
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
weight = weight.clamp(min=self.min_weight[0],
max=1.0-self.min_weight[1])
src1 = src1 * weight src1 = src1 * weight1
src2 = src2 * (1.0 - weight) src2 = src2 * (1.0 - weight1)
src1_dim = src1.shape[-1] src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1] src2_dim = src2.shape[-1]