Also apply limit on logit in SimpleCombiner

This commit is contained in:
Daniel Povey 2022-10-21 21:31:52 +08:00
parent bdbd2cfce6
commit e5fe3de17e

View File

@ -722,7 +722,13 @@ class SimpleCombiner(torch.nn.Module):
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
weight = (weight1 + weight2).sigmoid()
logit = (weight1 + weight2)
if self.training and random.random() < 0.1:
logit = penalize_abs_values_gt(logit,
limit=10.0,
penalty=1.0e-04)
weight = logit.sigmoid()
src2_part1 = src2[...,:dim1]
part1 = src1 * weight + src2_part1 * (1.0 - weight)
@ -1569,9 +1575,10 @@ class AttentionCombine(nn.Module):
scores = scores.masked_fill(mask, float('-inf'))
scores = penalize_abs_values_gt(scores,
limit=10.0,
penalty=1.0e-04)
if self.training and random.random() < 0.1:
scores = penalize_abs_values_gt(scores,
limit=10.0,
penalty=1.0e-04)
weights = scores.softmax(dim=1)