Implement a form of dropout for squeeze_weights, dropout-to-constant.

Daniel Povey 2022-11-17 10:54:29 +08:00
parent 99cd9f5788
commit fe51eea397


@@ -366,6 +366,7 @@ class ZipformerEncoderLayer(nn.Module):
# to work correctly.
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
squeeze_const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
bypass_max: FloatLike = 1.0,
) -> None:
@@ -380,7 +381,7 @@ class ZipformerEncoderLayer(nn.Module):
# ever becoming zero.
self.bypass_min = copy.deepcopy(bypass_min)
self.bypass_max = copy.deepcopy(bypass_max)
self.squeeze_const_attention_rate = copy.deepcopy(squeeze_const_attention_rate)
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -478,6 +479,17 @@ class ZipformerEncoderLayer(nn.Module):
key_padding_mask=src_key_padding_mask,
)
squeeze_weights = attn_weights[1:2]
if random.random() < float(self.squeeze_const_attention_rate):
# this form of dropout makes the attention-weights used for the
# squeeze-excite modules constant wherever they are not masked. The intention
# is to encourage these modules to do something similar to an averaging-over-time
# operation.
squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
# make sure they sum to 1 over the last axis.
squeeze_weights = squeeze_weights * (1.0 / squeeze_weights.sum(dim=-1, keepdim=True))
if torch.jit.is_scripting() or use_self_attn:
src = src + self.nonlin_attention_module(src,
attn_weights[0:1])
@@ -487,7 +499,7 @@ class ZipformerEncoderLayer(nn.Module):
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze1(src, attn_weights[1:2])
src = src + self.attention_squeeze1(src, squeeze_weights)
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn(
@@ -500,8 +512,7 @@ class ZipformerEncoderLayer(nn.Module):
# pooling module
if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze2(src, attn_weights[2:3])
src = src + self.attention_squeeze2(src, squeeze_weights)
src = self.norm_final(self.balancer(src))
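
For reference, below is a minimal standalone sketch of the dropout-to-constant operation added in this commit; the function name, tensor shape, and example values are illustrative assumptions, not code from the repository.

import random
import torch

def dropout_to_constant(squeeze_weights: torch.Tensor, rate: float,
                        training: bool = True) -> torch.Tensor:
    # squeeze_weights: attention weights, assumed here to be 0.0 at masked
    # (padding) positions and to sum to 1 over the last axis elsewhere.
    # rate: probability of replacing the weights with a constant distribution.
    if not training or random.random() >= rate:
        return squeeze_weights
    # Set every nonzero weight to 1, keeping masked positions at 0 ...
    w = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
    # ... then renormalize so each row sums to 1 again.  The result is uniform
    # attention over the unmasked frames, i.e. an averaging-over-time operation.
    return w / w.sum(dim=-1, keepdim=True).clamp(min=1.0)

# Example: the last two positions are masked, so their weights are zero.
attn = torch.tensor([[0.7, 0.3, 0.0, 0.0],
                     [0.1, 0.9, 0.0, 0.0]])
print(dropout_to_constant(attn, rate=1.0))
# tensor([[0.5000, 0.5000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000]])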