Add bias in weight module

This commit is contained in:
Daniel Povey 2022-10-31 17:10:28 +08:00
parent 5e51534fbc
commit b091ae5482

View File

@ -827,6 +827,7 @@ class SimpleCombiner(torch.nn.Module):
assert dim2 >= dim1
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
self.bias = nn.Parameter(torch.zeros(()))
self.min_weight = min_weight
def forward(self,
@ -844,7 +845,7 @@ class SimpleCombiner(torch.nn.Module):
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
logit = (weight1 + weight2)
logit = (weight1 + weight2 + self.bias)
if self.training and random.random() < 0.1:
logit = penalize_abs_values_gt(logit,