Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove one attention_squeeze module; halve dimension in NonlinAttention module; put schedule on balancer of ConvolutionModule
This commit is contained in:
parent a96b92fb54
commit 9ce99b150d
@@ -421,8 +421,7 @@ class ZipformerEncoderLayer(nn.Module):
                                              cnn_module_kernel)


-        self.attention_squeeze1 = AttentionSqueeze(embed_dim)
-        self.attention_squeeze2 = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim)

         self.norm_final = BasicNorm(embed_dim)
@@ -510,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, first_attn_weights[1:2])
+            src = src + self.attention_squeeze(src, first_attn_weights[1:2])

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
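A note on the `first_attn_weights[1:2]` argument above: slicing with `[1:2]` rather than indexing with `[1]` keeps the leading dimension, so the squeeze module receives a tensor with a size-1 leading axis. A minimal illustration; the 4-D shape used here is an assumption for demonstration, not necessarily the exact shape in zipformer.py:

import torch

# Hypothetical shape: attention weights for 3 weight sets stacked on dim 0.
first_attn_weights = torch.randn(3, 8, 100, 100).softmax(dim=-1)

w1 = first_attn_weights[1]    # shape (8, 100, 100): leading dim dropped
w2 = first_attn_weights[1:2]  # shape (1, 8, 100, 100): leading dim kept
print(w1.shape, w2.shape)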
@@ -521,9 +520,6 @@ class ZipformerEncoderLayer(nn.Module):

         src = src + self.feed_forward2(src)

-        # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, first_attn_weights[2:3])

         src = self.norm_final(self.balancer(src))
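Taken together, the preceding hunks mean the layer now applies a single attention-weighted pooling step (reusing one slice of the shared attention weights) rather than two. Below is a minimal sketch of a module with the call signature seen in the diff, assuming `(seq_len, batch, embed_dim)` inputs and `(1, batch, seq_len, seq_len)` weights; the real AttentionSqueeze in icefall adds bottleneck projections and balancer/whitening constraints omitted here:

import torch
import torch.nn as nn

class AttentionSqueezeSketch(nn.Module):
    # Illustrative pooling module matching the call
    # src + module(src, attn_weights); not the verbatim icefall code.
    def __init__(self, embed_dim: int):
        super().__init__()
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, embed_dim)
        # attn_weights: (1, batch, seq_len, seq_len), e.g. first_attn_weights[1:2]
        w = attn_weights.squeeze(0)                   # (batch, seq_len, seq_len)
        pooled = torch.matmul(w, x.transpose(0, 1))   # weighted average over time
        return self.out_proj(pooled.transpose(0, 1))  # (seq_len, batch, embed_dim)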
@@ -1432,11 +1428,11 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()

-        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.in_proj = nn.Linear(channels, channels, bias=True)

         # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
-            channels, channel_dim=-1,
+            channels // 2, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
                                                 (4000.0, 10.0),
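The in_proj now emits `channels` features instead of `2 * channels`, while the balancer (and, in the next hunk, out_proj) operate on `channels // 2`. That is consistent with the projection being split into two `channels // 2` halves, one driving the sigmoid gate and one being gated. A sketch of that halved data flow, under the assumption of such a gate/value split (balancer and attention weighting omitted, `channels` assumed even):

import torch
import torch.nn as nn

class NonlinAttentionSketch(nn.Module):
    # The gate/value chunking below is an assumption used to illustrate
    # why the inner width is now channels // 2.
    def __init__(self, channels: int):
        super().__init__()
        self.in_proj = nn.Linear(channels, channels, bias=True)        # was 2 * channels
        self.sigmoid = nn.Sigmoid()
        self.out_proj = nn.Linear(channels // 2, channels, bias=True)  # was channels in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.in_proj(x).chunk(2, dim=-1)  # two (channels // 2) halves
        return self.out_proj(self.sigmoid(gate) * value)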
@@ -1445,7 +1441,7 @@ class NonlinAttentionModule(nn.Module):
         self.sigmoid = nn.Sigmoid()

         self.activation = Identity()  # for diagnostics.
-        self.out_proj = ScaledLinear(channels, channels,
+        self.out_proj = ScaledLinear(channels // 2, channels,
                                      bias=True,
                                      initial_scale=0.05)
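ScaledLinear with `initial_scale=0.05` keeps the module's initial contribution to the residual path small. A sketch of that initialization behavior, assuming ScaledLinear behaves like an ordinary Linear whose default init is multiplied by `initial_scale` (the example sizes are hypothetical):

import torch
import torch.nn as nn

def scaled_linear_sketch(in_features: int, out_features: int,
                         bias: bool = True, initial_scale: float = 1.0) -> nn.Linear:
    # Scale down the default initialization so that, used on a residual
    # branch, the module starts out close to a no-op.
    linear = nn.Linear(in_features, out_features, bias=bias)
    with torch.no_grad():
        linear.weight.mul_(initial_scale)
        if bias:
            linear.bias.mul_(initial_scale)
    return linear

out_proj = scaled_linear_sketch(192, 384, bias=True, initial_scale=0.05)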
@@ -1532,7 +1528,10 @@ class ConvolutionModule(nn.Module):
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
             2 * channels, channel_dim=-1,
-            max_abs=10.0, min_positive=0.05, max_positive=1.0
+            max_abs=ScheduledFloat((0.0, 2.0),
+                                   (4000.0, 10.0),
+                                   default=1.0),
+            min_positive=0.05, max_positive=1.0
         )

         self.pre_sigmoid = Identity()  # before sigmoid; for diagnostics.
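The `max_abs` constraint is now a ScheduledFloat: a value that is piecewise-linear in the training batch count, here 2.0 at batch 0 rising to 10.0 by batch 4000, with `default=1.0` used when no batch count is available (e.g. under torch.jit scripting). A minimal sketch of that schedule, as a simplified stand-in for icefall's ScheduledFloat:

def scheduled_float(batch_count, points, default=1.0):
    # points: [(batch, value), ...] sorted by batch; linear interpolation
    # between them, clamped at the endpoints.
    if batch_count is None:
        return default
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)

print(scheduled_float(2000.0, [(0.0, 2.0), (4000.0, 10.0)]))  # 6.0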