Merge branch 'scaled_adam_exp268b' into scaled_adam_exp279

Daniel Povey 2022-11-04 14:42:00 +08:00
commit 31d9bbfb3c


@@ -331,8 +331,6 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
-        self.squeeze_excite = ModifiedSEModule(d_model)
         self.feed_forward1 = FeedforwardModule(d_model,
                                                feedforward_dim,
                                                dropout)
@@ -351,6 +349,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.conv_module2 = ConvolutionModule(d_model,
                                               cnn_module_kernel)
+        self.squeeze_excite = ModifiedSEModule(d_model)
         self.norm_final = BasicNorm(d_model)
         self.bypass_scale = nn.Parameter(torch.tensor(0.5))
@@ -430,11 +431,6 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
-        # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.squeeze_excite(src,
-                                            key_padding_mask=src_key_padding_mask)
         # multi-headed self-attention module
         use_self_attn = (random.random() > dynamic_dropout)
         if torch.jit.is_scripting() or use_self_attn:
@@ -460,6 +456,11 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
+        # pooling module
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.squeeze_excite(src,
+                                            key_padding_mask=src_key_padding_mask)
         src = self.norm_final(self.balancer(src))
         delta = src - src_orig
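
Net effect of the merge: in ZipformerEncoderLayer the squeeze-excite (pooling) submodule moves from before the self-attention block to after feed_forward3, just ahead of norm_final, and its construction in __init__ moves next to conv_module2 to match. Below is a minimal, self-contained sketch of that ordering; ToyEncoderLayer and ToySqueezeExcite are hypothetical stand-ins written for illustration only, not icefall's ModifiedSEModule, FeedforwardModule, BasicNorm or the real layer, and the dynamic-dropout skipping is reduced to the two conditions visible in the diff.

# Sketch of the reordered layer implied by this diff (stand-in modules,
# not icefall's real implementation).
import random

import torch
import torch.nn as nn


class ToySqueezeExcite(nn.Module):
    """Stand-in for ModifiedSEModule: pool over time, then gate channels."""

    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, d_model); average over time, scale each channel.
        pooled = x.mean(dim=0, keepdim=True)
        return x * self.gate(pooled)


class ToyEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward3 = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model)
        )
        self.squeeze_excite = ToySqueezeExcite(d_model)
        self.norm_final = nn.LayerNorm(d_model)
        self.bypass_scale = nn.Parameter(torch.tensor(0.5))

    def forward(self, src: torch.Tensor, dynamic_dropout: float = 0.05) -> torch.Tensor:
        src_orig = src

        # self-attention, randomly skipped during training as in the diff
        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            attn_out, _ = self.self_attn(src, src, src)
            src = src + attn_out

        src = src + self.feed_forward3(src)

        # pooling / squeeze-excite now runs HERE, after feed_forward3
        # (before this merge it ran before self-attention)
        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            src = src + self.squeeze_excite(src)

        # the real layer applies a balancer before norm_final; omitted here
        src = self.norm_final(src)
        delta = src - src_orig
        return src_orig + delta * self.bypass_scale


layer = ToyEncoderLayer(d_model=64)
out = layer(torch.randn(10, 2, 64))  # (seq_len, batch, d_model)
print(out.shape)  # torch.Size([10, 2, 64])

In the hunks shown, the only change is this reordering: the squeeze-excite call and its key_padding_mask argument are untouched, only the point in the forward pass (and the corresponding line in __init__) moves.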