Add bias in weight module
commit b091ae5482
parent 5e51534fbc
@@ -827,6 +827,7 @@ class SimpleCombiner(torch.nn.Module):
         assert dim2 >= dim1
         self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
         self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+        self.bias = nn.Parameter(torch.zeros(()))
         self.min_weight = min_weight

     def forward(self,
@@ -844,7 +845,7 @@ class SimpleCombiner(torch.nn.Module):

         weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
         weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
-        logit = (weight1 + weight2)
+        logit = (weight1 + weight2 + self.bias)

         if self.training and random.random() < 0.1:
             logit = penalize_abs_values_gt(logit,
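The diff shows only the two touched hunks of SimpleCombiner. For context, below is a minimal runnable sketch of a combiner with the new bias term. Anything not visible in the hunks is an assumption: the class name SimpleCombinerSketch is hypothetical, the sigmoid gate and the truncation of src2 to dim1 are guesses at the omitted forward body, and the min_weight clamping and penalize_abs_values_gt regularization are left out.

import torch
import torch.nn as nn

class SimpleCombinerSketch(nn.Module):
    """Combines two feature streams with a per-frame scalar mixing weight."""

    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        assert dim2 >= dim1
        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
        # The change in this commit: a learnable scalar bias on the mixing
        # logit, so the initial mix is not pinned near 50/50.
        self.bias = nn.Parameter(torch.zeros(()))

    def forward(self, src1: torch.Tensor, src2: torch.Tensor) -> torch.Tensor:
        # Per-frame scalar logits, exactly as in the hunks above.
        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
        logit = weight1 + weight2 + self.bias
        # Assumption: the logit is squashed to (0, 1) and used as a convex
        # mixing weight; src2 is truncated to dim1 since dim2 >= dim1.
        weight = logit.sigmoid()
        return weight * src1 + (1.0 - weight) * src2[..., : src1.shape[-1]]

combiner = SimpleCombinerSketch(dim1=256, dim2=384)
out = combiner(torch.randn(4, 10, 256), torch.randn(4, 10, 384))
print(out.shape)  # torch.Size([4, 10, 256])

Because self.bias is a 0-dim parameter initialized to zero, it broadcasts over the (batch, time, 1) logit and leaves behavior unchanged at initialization; training can then shift the gate toward one stream globally.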