Make sub-module dropped out independently.

This commit is contained in:
Daniel Povey 2022-11-09 14:15:56 +08:00
parent 423f9e3026
commit 3ff3f440ee

View File

@ -485,35 +485,30 @@ class ZipformerEncoderLayer(nn.Module):
# dropout rate for submodules that interact with time. # dropout rate for submodules that interact with time.
dynamic_dropout = self.get_dynamic_dropout_rate() dynamic_dropout = self.get_dynamic_dropout_rate()
# multi-headed self-attention module # attn_weights: (num_heads, batch_size, seq_len, seq_len)
# TODO: make the various attention-using models be dropped attn_weights = self.self_attn_weights(
# out independently. src,
use_self_attn = (random.random() > dynamic_dropout) pos_emb=pos_emb,
if torch.jit.is_scripting() or use_self_attn: attn_mask=src_mask,
# attn_weights: (num_heads, batch_size, seq_len, seq_len) key_padding_mask=src_key_padding_mask,
attn_weights = self.self_attn_weights( )
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.self_attn1( src = src + self.self_attn1(
src, attn_weights) src, attn_weights)
# convolution module # convolution module
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.nonlin_attention_module(src, src = src + self.nonlin_attention_module(src,
attn_weights[0:1]) attn_weights[0:1])
src = src + self.feed_forward2(src) src = src + self.feed_forward2(src)
# pooling module # pooling module
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.attention_squeeze1(src, attn_weights[1:2]) src = src + self.attention_squeeze1(src, attn_weights[1:2])
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.self_attn2( src = src + self.self_attn2(
src, attn_weights) src, attn_weights)
@ -523,10 +518,9 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.feed_forward3(src) src = src + self.feed_forward3(src)
# pooling module # pooling module
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.attention_squeeze2(src, attn_weights[2:3]) src = src + self.attention_squeeze2(src, attn_weights[2:3])
src = self.norm_final(self.balancer(src)) src = self.norm_final(self.balancer(src))
delta = src - src_orig delta = src - src_orig