mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Use half the dimension in AttentionSqueeze.
This commit is contained in:
parent
6e598cb18d
commit
d29e3d89e5
@ -440,7 +440,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
cnn_module_kernel)
|
||||
|
||||
|
||||
self.attention_squeeze = AttentionSqueeze(embed_dim)
|
||||
self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
|
||||
|
||||
self.norm_final = BasicNorm(embed_dim)
|
||||
|
||||
@ -1323,11 +1323,12 @@ class AttentionSqueeze(nn.Module):
|
||||
"""
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
bottleneck_dim: int = 16):
|
||||
super().__init__()
|
||||
self.bottleneck_dim = bottleneck_dim
|
||||
|
||||
self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
|
||||
self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
|
||||
bias=False,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
|
||||
|
||||
@ -1355,13 +1356,13 @@ class AttentionSqueeze(nn.Module):
|
||||
# Make them run with very low probability, since only a small
|
||||
# application of these balancers should be enough to stop such "drift".
|
||||
self.scale_balancer = ActivationBalancer(
|
||||
embed_dim, channel_dim=-1,
|
||||
hidden_dim, channel_dim=-1,
|
||||
min_positive=0.2, max_positive=0.8,
|
||||
min_abs=0.2, max_abs=1.0,
|
||||
min_prob=0.05,
|
||||
)
|
||||
self.activation_balancer = ActivationBalancer(
|
||||
embed_dim, channel_dim=-1,
|
||||
hidden_dim, channel_dim=-1,
|
||||
min_positive=0.2, max_positive=0.8,
|
||||
min_abs=0.2, max_abs=1.0,
|
||||
min_prob=0.05,
|
||||
@ -1372,10 +1373,11 @@ class AttentionSqueeze(nn.Module):
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
|
||||
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
|
||||
|
||||
self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
|
||||
self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
|
||||
aux_grad_scale=_aux_grad_scale(),
|
||||
prob=_aux_grad_prob_out(),
|
||||
bias=False, initial_scale=0.05)
|
||||
|
||||
def forward(self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user