Revert "Make sub-module dropped out independently."

This reverts commit 3ff3f440ee6d2a367cc3cc45e40f8eb69d122861.
This commit is contained in:
Daniel Povey 2022-11-11 21:36:36 +08:00
parent 742bcaa340
commit f7aff4f507

View File

@ -487,6 +487,11 @@ class ZipformerEncoderLayer(nn.Module):
# dropout rate for submodules that interact with time. # dropout rate for submodules that interact with time.
dynamic_dropout = self.get_dynamic_dropout_rate() dynamic_dropout = self.get_dynamic_dropout_rate()
# multi-headed self-attention module
# TODO: make the various attention-using models be dropped
# out independently.
use_self_attn = (random.random() > dynamic_dropout)
if torch.jit.is_scripting() or use_self_attn:
# attn_weights: (num_heads, batch_size, seq_len, seq_len) # attn_weights: (num_heads, batch_size, seq_len, seq_len)
attn_weights = self.self_attn_weights( attn_weights = self.self_attn_weights(
src, src,
@ -495,22 +500,22 @@ class ZipformerEncoderLayer(nn.Module):
key_padding_mask=src_key_padding_mask, key_padding_mask=src_key_padding_mask,
) )
if torch.jit.is_scripting() or random.random() > dynamic_dropout: if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn1( src = src + self.self_attn1(
src, attn_weights) src, attn_weights)
# convolution module # convolution module
if torch.jit.is_scripting() or random.random() > dynamic_dropout: if torch.jit.is_scripting() or use_self_attn:
src = src + self.nonlin_attention_module(src, src = src + self.nonlin_attention_module(src,
attn_weights[0:1]) attn_weights[0:1])
src = src + self.feed_forward2(src) src = src + self.feed_forward2(src)
# pooling module # pooling module
if torch.jit.is_scripting() or random.random() > dynamic_dropout: if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze1(src, attn_weights[1:2]) src = src + self.attention_squeeze1(src, attn_weights[1:2])
if torch.jit.is_scripting() or random.random() > dynamic_dropout: if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn2( src = src + self.self_attn2(
src, attn_weights) src, attn_weights)
@ -520,9 +525,10 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.feed_forward3(src) src = src + self.feed_forward3(src)
# pooling module # pooling module
if torch.jit.is_scripting() or random.random() > dynamic_dropout: if torch.jit.is_scripting() or use_self_attn:
src = src + self.attention_squeeze2(src, attn_weights[2:3]) src = src + self.attention_squeeze2(src, attn_weights[2:3])
src = self.norm_final(self.balancer(src)) src = self.norm_final(self.balancer(src))
delta = src - src_orig delta = src - src_orig