Use half the dimension in AttentionSqueeze.

Daniel Povey 2022-12-07 18:14:06 +08:00
parent 6e598cb18d
commit d29e3d89e5


@@ -440,7 +440,7 @@ class ZipformerEncoderLayer(nn.Module):
             cnn_module_kernel)
-        self.attention_squeeze = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
         self.norm_final = BasicNorm(embed_dim)
@@ -1323,11 +1323,12 @@ class AttentionSqueeze(nn.Module):
     """
     def __init__(self,
                  embed_dim: int,
+                 hidden_dim: int,
                  bottleneck_dim: int = 16):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
-        self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
                                          bias=False,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
@@ -1355,13 +1356,13 @@ class AttentionSqueeze(nn.Module):
         # Make them run with very low probability, since only a small
         # application of these balancers should be enough to stop such "drift".
         self.scale_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
         )
         self.activation_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
@@ -1372,10 +1373,11 @@ class AttentionSqueeze(nn.Module):
             grad_scale=0.01)
-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
-        self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
+        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
+        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          prob=_aux_grad_prob_out(),
                                           bias=False, initial_scale=0.05)
     def forward(self,
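
For context, a minimal shape-only sketch of the dimension flow this commit introduces: AttentionSqueeze now projects down to an internal hidden dimension before the bottleneck and back up to embed_dim at the output. Plain nn.Linear stands in for LinearWithAuxLoss/ScaledLinear, the balancers, whitening and attention-weighted pooling are omitted, and the to_bottleneck_proj name is an assumption not shown in the diff; only the shapes follow the changes above.

import torch.nn as nn

class AttentionSqueezeSketch(nn.Module):
    """Shape sketch: embed_dim -> hidden_dim -> bottleneck_dim -> hidden_dim -> embed_dim."""
    def __init__(self, embed_dim: int, hidden_dim: int, bottleneck_dim: int = 16):
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, hidden_dim, bias=False)        # was (embed_dim, embed_dim)
        self.to_bottleneck_proj = nn.Linear(hidden_dim, bottleneck_dim)    # assumed counterpart of from_bottleneck_proj
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)  # was (bottleneck_dim, embed_dim)
        self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=False)       # was (embed_dim, embed_dim)

# As constructed in ZipformerEncoderLayer after this commit:
# AttentionSqueeze(embed_dim, embed_dim // 2), i.e. hidden_dim is half of embed_dim.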