Remove one attention_squeeze module; halve dimension in NonlinAttention module; put schedule on balancer of ConvolutionModule
parent a96b92fb54
commit 9ce99b150d
@@ -421,8 +421,7 @@ class ZipformerEncoderLayer(nn.Module):
                                               cnn_module_kernel)
 
-        self.attention_squeeze1 = AttentionSqueeze(embed_dim)
-        self.attention_squeeze2 = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim)
 
         self.norm_final = BasicNorm(embed_dim)
 
@@ -510,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, first_attn_weights[1:2])
+            src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -521,9 +520,6 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, first_attn_weights[2:3])
 
         src = self.norm_final(self.balancer(src))
 
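The three hunks above collapse the two pooling branches into one: only self.attention_squeeze remains, fed by the head slice first_attn_weights[1:2], while the branch that consumed first_attn_weights[2:3] is deleted. A minimal illustration of that slicing, assuming (as the [1:2] / [2:3] indexing suggests) that the leading dimension of first_attn_weights indexes attention heads; the shapes below are invented for the example:

import torch

# Hypothetical shapes, illustration only: (num_heads, batch, seq, seq).
num_heads, batch, seq = 8, 2, 50
first_attn_weights = torch.softmax(torch.randn(num_heads, batch, seq, seq), dim=-1)

kept = first_attn_weights[1:2]     # head feeding the remaining attention_squeeze
dropped = first_attn_weights[2:3]  # fed attention_squeeze2 before this commit
print(kept.shape, dropped.shape)   # torch.Size([1, 2, 50, 50]) for both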
@@ -1432,11 +1428,11 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.in_proj = nn.Linear(channels, channels, bias=True)
 
         # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
-            channels, channel_dim=-1,
+            channels // 2, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
                                                 (4000.0, 10.0),
@@ -1445,7 +1441,7 @@ class NonlinAttentionModule(nn.Module):
         self.sigmoid = nn.Sigmoid()
 
         self.activation = Identity()  # for diagnostics.
-        self.out_proj = ScaledLinear(channels, channels,
+        self.out_proj = ScaledLinear(channels // 2, channels,
                                      bias=True,
                                      initial_scale=0.05)
 
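The two NonlinAttentionModule hunks halve the internal dimension: in_proj now produces channels (rather than 2 * channels) features, the balancer watches channels // 2 of them, and out_proj maps channels // 2 back to channels. The module's forward() is not part of this diff; the sketch below only checks that the new shapes compose, assuming the projection output is split into a value half and a sigmoid-gated half, and substituting plain PyTorch layers for ActivationBalancer and ScaledLinear:

import torch
import torch.nn as nn

class NonlinAttentionShapeSketch(nn.Module):
    """Shape bookkeeping only; the attention-weighting step is omitted."""
    def __init__(self, channels: int):
        super().__init__()
        self.in_proj = nn.Linear(channels, channels, bias=True)        # was channels -> 2 * channels
        self.sigmoid = nn.Sigmoid()
        self.out_proj = nn.Linear(channels // 2, channels, bias=True)  # was channels -> channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        v, g = self.in_proj(x).chunk(2, dim=-1)    # two halves of size channels // 2
        return self.out_proj(v * self.sigmoid(g))

x = torch.randn(10, 384)
print(NonlinAttentionShapeSketch(384)(x).shape)    # torch.Size([10, 384])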
@@ -1532,7 +1528,10 @@ class ConvolutionModule(nn.Module):
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
             2 * channels, channel_dim=-1,
-            max_abs=10.0, min_positive=0.05, max_positive=1.0
+            max_abs=ScheduledFloat((0.0, 2.0),
+                                   (4000.0, 10.0),
+                                   default=1.0),
+            min_positive=0.05, max_positive=1.0
         )
 
         self.pre_sigmoid = Identity()  # before sigmoid; for diagnostics.
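This hunk replaces the fixed max_abs=10.0 with the same (0.0, 2.0) -> (4000.0, 10.0) schedule that the NonlinAttentionModule balancer already uses, so the constraint starts tight and relaxes over the first 4000 batches. A rough, illustration-only stand-in for how such a schedule evaluates, assuming ScheduledFloat interpolates linearly between its (batch_count, value) breakpoints and returns default when no batch count is available:

def scheduled_float(batch_count, points, default=None):
    # points: (batch, value) breakpoints sorted by batch
    if batch_count is None:
        return default
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (b0, v0), (b1, v1) in zip(points, points[1:]):
        if b0 <= batch_count <= b1:
            return v0 + (batch_count - b0) / (b1 - b0) * (v1 - v0)

points = [(0.0, 2.0), (4000.0, 10.0)]  # the schedule this commit puts on max_abs
for batch in (None, 0, 1000, 2000, 4000, 8000):
    print(batch, scheduled_float(batch, points, default=1.0))
# None -> 1.0, 0 -> 2.0, 1000 -> 4.0, 2000 -> 6.0, 4000 -> 10.0, 8000 -> 10.0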