Remove one attention_squeeze module; halve dimension in NonlinAttention module; put schedule on balancer of ConvolutionModule

This commit is contained in:
Daniel Povey 2022-11-26 19:42:33 +08:00
parent a96b92fb54
commit 9ce99b150d

View File

@ -421,8 +421,7 @@ class ZipformerEncoderLayer(nn.Module):
cnn_module_kernel)
self.attention_squeeze1 = AttentionSqueeze(embed_dim)
self.attention_squeeze2 = AttentionSqueeze(embed_dim)
self.attention_squeeze = AttentionSqueeze(embed_dim)
self.norm_final = BasicNorm(embed_dim)
@ -510,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze1(src, first_attn_weights[1:2])
src = src + self.attention_squeeze(src, first_attn_weights[1:2])
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn(
@ -521,9 +520,6 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.feed_forward2(src)
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze2(src, first_attn_weights[2:3])
src = self.norm_final(self.balancer(src))
@ -1432,11 +1428,11 @@ class NonlinAttentionModule(nn.Module):
) -> None:
super().__init__()
self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
self.in_proj = nn.Linear(channels, channels, bias=True)
# balancer that goes before the sigmoid.
self.balancer = ActivationBalancer(
channels, channel_dim=-1,
channels // 2, channel_dim=-1,
min_positive=0.05, max_positive=1.0,
min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
(4000.0, 10.0),
@ -1445,7 +1441,7 @@ class NonlinAttentionModule(nn.Module):
self.sigmoid = nn.Sigmoid()
self.activation = Identity() # for diagnostics.
self.out_proj = ScaledLinear(channels, channels,
self.out_proj = ScaledLinear(channels // 2, channels,
bias=True,
initial_scale=0.05)
@ -1532,7 +1528,10 @@ class ConvolutionModule(nn.Module):
# the correct range.
self.deriv_balancer1 = ActivationBalancer(
2 * channels, channel_dim=-1,
max_abs=10.0, min_positive=0.05, max_positive=1.0
max_abs=ScheduledFloat((0.0, 2.0),
(4000.0, 10.0),
default=1.0),
min_positive=0.05, max_positive=1.0
)
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.