Revert "Make sub-module dropped out independently."

This reverts commit 3ff3f440ee6d2a367cc3cc45e40f8eb69d122861.
Daniel Povey 2022-11-11 21:36:36 +08:00
parent 742bcaa340
commit f7aff4f507


@@ -487,30 +487,35 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
+        # multi-headed self-attention module
+        # TODO: make the various attention-using models be dropped
+        # out independently.
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+            attn_weights = self.self_attn_weights(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn1(
                 src, attn_weights)
 
         # convolution module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
 
         src = src + self.feed_forward2(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn2(
                 src, attn_weights)
@@ -520,9 +525,10 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze2(src, attn_weights[2:3])
 
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
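
For context, the gating pattern this revert restores can be sketched outside the diff: a single random draw per forward pass decides whether all attention-consuming sub-modules run, so the attention weights are only computed when something downstream will actually use them, and every consumer that runs sees weights that exist. The sketch below is illustrative only; SharedAttnDropoutLayer and its plain Linear stand-ins are assumed names for this example, not part of icefall or the real ZipformerEncoderLayer.

import random

import torch
import torch.nn as nn


class SharedAttnDropoutLayer(nn.Module):
    """Illustrative sketch of the shared-flag gating restored by this revert.

    One random.random() draw per forward pass gates every branch that
    consumes the attention weights, instead of an independent draw per
    sub-module (the behaviour of the reverted commit 3ff3f440).
    """

    def __init__(self, d_model: int, dynamic_dropout: float = 0.1):
        super().__init__()
        self.dynamic_dropout = dynamic_dropout
        # plain Linear layers stand in for self_attn_weights, self_attn1,
        # attention_squeeze1, feed_forward2, etc.
        self.attn_weights_proj = nn.Linear(d_model, d_model)
        self.attn_branch = nn.Linear(d_model, d_model)
        self.squeeze_branch = nn.Linear(d_model, d_model)
        self.feed_forward = nn.Linear(d_model, d_model)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # single shared decision for all attention-using sub-modules
        use_self_attn = random.random() > self.dynamic_dropout

        if torch.jit.is_scripting() or use_self_attn:
            # only computed when a consumer below will run
            attn_weights = self.attn_weights_proj(src)

        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.attn_branch(attn_weights)

        # the feed-forward branch is not gated, matching the diff above
        src = src + self.feed_forward(src)

        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.squeeze_branch(attn_weights)

        return src

With independent draws per branch, as in the reverted version, the weights had to be computed unconditionally because any one consumer might still fire; tying every consumer to one flag lets the whole attention path be skipped together.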