mirror of https://github.com/k2-fsa/icefall.git
Implement a form of dropout for squeeze_weights, dropout-to-constant.
parent 99cd9f5788
commit fe51eea397
@@ -366,6 +366,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+            squeeze_const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -380,7 +381,7 @@ class ZipformerEncoderLayer(nn.Module):
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
         self.bypass_max = copy.deepcopy(bypass_max)
-
+        self.squeeze_const_attention_rate = copy.deepcopy(squeeze_const_attention_rate)
 
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -478,6 +479,17 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
+
+        squeeze_weights = attn_weights[1:2]
+        if random.random() < float(self.squeeze_const_attention_rate):
+            # this form of dropout makes the attention-weights used for the
+            # squeeze-excite modules constant wherever they are not masked. The intention
+            # is to encourage these modules to do something similar to an averaging-over-time
+            # operation.
+            squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
+            # make sure they sum to 1 over the last axis.
+            squeeze_weights = squeeze_weights * (1.0 / squeeze_weights.sum(dim=-1, keepdim=True))
+
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
@@ -487,7 +499,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, attn_weights[1:2])
+            src = src + self.attention_squeeze1(src, squeeze_weights)
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -500,8 +512,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, attn_weights[2:3])
-
+            src = src + self.attention_squeeze2(src, squeeze_weights)
 
         src = self.norm_final(self.balancer(src))
||||