mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Also apply limit on logit in SimpleCombiner
This commit is contained in:
parent
bdbd2cfce6
commit
e5fe3de17e
@ -722,7 +722,13 @@ class SimpleCombiner(torch.nn.Module):
|
||||
|
||||
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
|
||||
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
|
||||
weight = (weight1 + weight2).sigmoid()
|
||||
logit = (weight1 + weight2)
|
||||
|
||||
if self.training and random.random() < 0.1:
|
||||
logit = penalize_abs_values_gt(logit,
|
||||
limit=10.0,
|
||||
penalty=1.0e-04)
|
||||
weight = logit.sigmoid()
|
||||
|
||||
src2_part1 = src2[...,:dim1]
|
||||
part1 = src1 * weight + src2_part1 * (1.0 - weight)
|
||||
@ -1569,9 +1575,10 @@ class AttentionCombine(nn.Module):
|
||||
|
||||
scores = scores.masked_fill(mask, float('-inf'))
|
||||
|
||||
scores = penalize_abs_values_gt(scores,
|
||||
limit=10.0,
|
||||
penalty=1.0e-04)
|
||||
if self.training and random.random() < 0.1:
|
||||
scores = penalize_abs_values_gt(scores,
|
||||
limit=10.0,
|
||||
penalty=1.0e-04)
|
||||
|
||||
weights = scores.softmax(dim=1)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user