mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add another balancer to ZipformerEncoderLayer, prior to output.
This commit is contained in:
parent
0c3530a6fd
commit
da0623aa7f
@ -452,20 +452,25 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
|
self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
|
||||||
|
|
||||||
self.norm_final = BasicNorm(embed_dim)
|
self.norm = BasicNorm(embed_dim)
|
||||||
|
|
||||||
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
|
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
|
||||||
|
|
||||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
self.balancer1 = Balancer(
|
||||||
self.balancer = Balancer(
|
|
||||||
embed_dim, channel_dim=-1,
|
embed_dim, channel_dim=-1,
|
||||||
min_positive=0.45, max_positive=0.55,
|
min_positive=0.45, max_positive=0.55,
|
||||||
min_abs=1.0, max_abs=6.0,
|
min_abs=1.0, max_abs=4.0,
|
||||||
)
|
)
|
||||||
self.whiten = Whiten(num_groups=1,
|
self.whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
||||||
prob=(0.025, 0.25),
|
prob=(0.025, 0.25),
|
||||||
grad_scale=0.01)
|
grad_scale=0.01)
|
||||||
|
self.balancer2 = Balancer(
|
||||||
|
embed_dim, channel_dim=-1,
|
||||||
|
min_positive=0.45, max_positive=0.55,
|
||||||
|
min_abs=0.5, max_abs=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_attention_weights(self):
|
def remove_attention_weights(self):
|
||||||
self.self_attn_weights = None
|
self.self_attn_weights = None
|
||||||
@ -571,12 +576,15 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
src = src + self.feed_forward2(src)
|
src = src + self.feed_forward2(src)
|
||||||
|
|
||||||
src = self.balancer(src)
|
src = self.balancer1(src)
|
||||||
src = self.norm_final(src)
|
src = self.norm(src)
|
||||||
|
|
||||||
bypass_scale = self.get_bypass_scale(src.shape[1])
|
bypass_scale = self.get_bypass_scale(src.shape[1])
|
||||||
src = src * bypass_scale + src_orig * (1.0 - bypass_scale)
|
# the next line is equivalent to: src = src * bypass_scale + src_orig *
|
||||||
|
# (1.0 - bypass_scale), but more memory efficient for backprop.
|
||||||
|
src = src_orig + (src - src_orig) * bypass_scale
|
||||||
|
|
||||||
|
src = self.balancer2(src)
|
||||||
src = self.whiten(src)
|
src = self.whiten(src)
|
||||||
|
|
||||||
return src, attn_weights
|
return src, attn_weights
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user