mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add bias in weight module
This commit is contained in:
parent
5e51534fbc
commit
b091ae5482
@ -827,6 +827,7 @@ class SimpleCombiner(torch.nn.Module):
|
||||
assert dim2 >= dim1
|
||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||
self.bias = nn.Parameter(torch.zeros(()))
|
||||
self.min_weight = min_weight
|
||||
|
||||
def forward(self,
|
||||
@ -844,7 +845,7 @@ class SimpleCombiner(torch.nn.Module):
|
||||
|
||||
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
|
||||
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
|
||||
logit = (weight1 + weight2)
|
||||
logit = (weight1 + weight2 + self.bias)
|
||||
|
||||
if self.training and random.random() < 0.1:
|
||||
logit = penalize_abs_values_gt(logit,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user