Make sub-modules be dropped out independently.

Daniel Povey 2022-11-09 14:15:56 +08:00
parent 423f9e3026
commit 3ff3f440ee


@@ -485,35 +485,30 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
         # multi-headed self-attention module
-        # TODO: make the various attention-using models be dropped
-        # out independently.
-        use_self_attn = (random.random() > dynamic_dropout)
-        if torch.jit.is_scripting() or use_self_attn:
-            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-            attn_weights = self.self_attn_weights(
-                src,
-                pos_emb=pos_emb,
-                attn_mask=src_mask,
-                key_padding_mask=src_key_padding_mask,
-            )
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        attn_weights = self.self_attn_weights(
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.self_attn1(
                 src, attn_weights)
         # convolution module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
         src = src + self.feed_forward2(src)
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.self_attn2(
                 src, attn_weights)
@@ -523,10 +518,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.attention_squeeze2(src, attn_weights[2:3])
         src = self.norm_final(self.balancer(src))
         delta = src - src_orig
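
The pattern this commit introduces can be illustrated with a minimal, self-contained sketch (not the actual ZipformerEncoderLayer; ToyEncoderLayer and its sub-module names are hypothetical): each residual sub-module now makes its own random draw against dynamic_dropout, so sub-modules are bypassed independently during training rather than all sharing the single use_self_attn decision.

# Minimal sketch of per-sub-module stochastic bypass, assuming a toy layer
# with three residual branches. Names here are illustrative only.
import random

import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    def __init__(self, d_model: int, dynamic_dropout: float = 0.1) -> None:
        super().__init__()
        self.dynamic_dropout = dynamic_dropout
        self.sub1 = nn.Linear(d_model, d_model)
        self.sub2 = nn.Linear(d_model, d_model)
        self.sub3 = nn.Linear(d_model, d_model)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Each sub-module gets its own random draw, so on a given step some
        # residual branches are applied and others are skipped, independently.
        # torch.jit.is_scripting() short-circuits the randomness so a scripted
        # model always runs every branch.
        for sub in (self.sub1, self.sub2, self.sub3):
            if torch.jit.is_scripting() or random.random() > self.dynamic_dropout:
                src = src + sub(src)
        return src


layer = ToyEncoderLayer(d_model=8, dynamic_dropout=0.5)
out = layer(torch.randn(2, 8))  # a random subset of sub1/sub2/sub3 is applied

Drawing per sub-module keeps the expected number of active branches the same as the old shared draw, but decorrelates which branches are skipped on any given step, which is the point of the change.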