Add out_balancer for attention_squeeze, similar to nonlin_attention_module.

Daniel Povey 2022-12-12 23:29:42 +08:00
parent f4ff6188d9
commit 7920fa7726


@@ -1403,6 +1403,13 @@ class AttentionSqueeze(nn.Module):
             prob=_aux_grad_prob_out(),
             bias=False, initial_scale=0.05)
+        self.out_balancer = ActivationBalancer(
+            channels, channel_dim=-1,
+            min_positive=0.4, max_positive=0.5,
+            min_abs=ScheduledFloat((0.0, 0.002), (8000.0, 0.02), (20000.0, 0.01)),
+        )
 
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor):
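
Note: this hunk only shows the balancer being constructed in __init__; the corresponding change in forward() is not visible here. The sketch below illustrates where such an out_balancer would typically be applied, mirroring the commit message's reference to nonlin_attention_module. It is a minimal, hypothetical illustration: names such as out_proj, the tensor shapes, and the use of an identity stand-in for ActivationBalancer (whose forward pass returns its input unchanged and only adjusts gradients in backward) are assumptions, not taken from this diff.

    # Hypothetical sketch, not the repository's actual AttentionSqueeze code.
    import torch
    import torch.nn as nn
    from torch import Tensor

    class _IdentityBalancer(nn.Module):
        """Stand-in for ActivationBalancer: value-wise it is an identity in
        the forward pass; the real module modifies gradients in backward to
        steer activation statistics toward the configured range."""
        def forward(self, x: Tensor) -> Tensor:
            return x

    class AttentionSqueezeSketch(nn.Module):
        def __init__(self, embed_dim: int):
            super().__init__()
            # `out_proj` is an assumed name for the output projection.
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
            # In the real code this would be the ActivationBalancer added
            # by this commit (channel_dim=-1, min_positive=0.4, etc.).
            self.out_balancer = _IdentityBalancer()

        def forward(self, x: Tensor, attn_weights: Tensor) -> Tensor:
            # ... attention-weighted squeeze computation omitted ...
            x = self.out_proj(x)
            # New in this commit (placement assumed): balance the statistics
            # of the module's output before it is returned.
            x = self.out_balancer(x)
            return x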