Use half the dimension in AttentionSqueeze.
This commit is contained in:
parent 6e598cb18d
commit d29e3d89e5
@@ -440,7 +440,7 @@ class ZipformerEncoderLayer(nn.Module):
                                              cnn_module_kernel)


-        self.attention_squeeze = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

         self.norm_final = BasicNorm(embed_dim)

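The call-site change above is the only interface change in the encoder layer: the new second argument is the squeeze module's internal hidden dimension, set to half of the embedding dimension. A trivial illustration, with an embedding dimension assumed here for the example only (384 is not taken from this commit):

    embed_dim = 384               # illustrative value, not from this commit
    hidden_dim = embed_dim // 2   # 192, the new second argument to AttentionSqueeze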
@@ -1323,11 +1323,12 @@ class AttentionSqueeze(nn.Module):
     """
     def __init__(self,
                  embed_dim: int,
+                 hidden_dim: int,
                  bottleneck_dim: int = 16):
        super().__init__()
        self.bottleneck_dim = bottleneck_dim

-        self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
                                          bias=False,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())

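A minimal shape-check sketch of what this constructor change does to the input projection, using a plain torch.nn.Linear as a stand-in for icefall's LinearWithAuxLoss; the dimension values and the (seq_len, batch, embed_dim) layout below are assumptions for illustration, not part of this commit:

    import torch
    import torch.nn as nn

    # nn.Linear stands in for LinearWithAuxLoss; values are illustrative only.
    embed_dim, hidden_dim = 384, 192
    in_proj = nn.Linear(embed_dim, hidden_dim, bias=False)  # previously embed_dim -> embed_dim

    x = torch.randn(10, 2, embed_dim)   # assumed (seq_len, batch, embed_dim) layout
    print(in_proj(x).shape)             # torch.Size([10, 2, 192])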
@@ -1355,13 +1356,13 @@ class AttentionSqueeze(nn.Module):
         # Make them run with very low probability, since only a small
         # application of these balancers should be enough to stop such "drift".
         self.scale_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
         )
         self.activation_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
@@ -1372,10 +1373,11 @@ class AttentionSqueeze(nn.Module):
             grad_scale=0.01)


-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
+        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)

-        self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
+        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          prob=_aux_grad_prob_out(),
                                           bias=False, initial_scale=0.05)

     def forward(self,
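Taken together, the hunks change the dimension flow inside AttentionSqueeze from embed_dim everywhere to embed_dim -> hidden_dim on the way in, hidden_dim -> embed_dim on the way out, with the bottleneck branch now producing a hidden_dim-sized scale. The stand-alone sketch below mirrors only those shapes with plain PyTorch layers; the to_bottleneck name, the sigmoid scaling path, and the constant values are assumptions for illustration, and the diff does not show the real forward() body, which also uses the LinearWithAuxLoss / ScaledLinear / ActivationBalancer machinery omitted here.

    import torch
    import torch.nn as nn

    class SqueezeShapeSketch(nn.Module):
        """Shape-only stand-in for AttentionSqueeze after this commit.

        Only the projection dimensions follow the diff; pooling, auxiliary
        losses and activation balancing in the real module are omitted.
        """
        def __init__(self, embed_dim: int, hidden_dim: int, bottleneck_dim: int = 16):
            super().__init__()
            self.in_proj = nn.Linear(embed_dim, hidden_dim, bias=False)        # was embed_dim -> embed_dim
            self.to_bottleneck = nn.Linear(hidden_dim, bottleneck_dim)         # hypothetical name/path
            self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)  # was bottleneck_dim -> embed_dim
            self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=False)       # was embed_dim -> embed_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.in_proj(x)                                                  # (..., hidden_dim)
            scale = self.from_bottleneck_proj(self.to_bottleneck(h).sigmoid())  # hypothetical scaling path
            return self.out_proj(h * scale)                                     # back to (..., embed_dim)

    embed_dim = 384                                   # illustrative value
    m = SqueezeShapeSketch(embed_dim, embed_dim // 2)
    print(m(torch.randn(10, 2, embed_dim)).shape)     # torch.Size([10, 2, 384])

Since the in_proj, out_proj and from_bottleneck_proj weights scale linearly with this internal width, halving it roughly halves their parameter count and matmul cost, which is presumably the motivation for the change.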