Restore feedforward3 module

This commit is contained in:
Daniel Povey 2022-11-03 12:41:19 +08:00
parent 0379ab57a2
commit a27670d097

View File

@ -338,8 +338,12 @@ class ZipformerEncoderLayer(nn.Module):
dropout)
self.feed_forward2 = FeedforwardModule(d_model,
feedforward_dim,
dropout)
feedforward_dim,
dropout)
self.feed_forward3 = FeedforwardModule(d_model,
feedforward_dim,
dropout)
self.conv_module1 = ConvolutionModule(d_model,
cnn_module_kernel)
@ -446,13 +450,15 @@ class ZipformerEncoderLayer(nn.Module):
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward2(src)
if torch.jit.is_scripting() or use_self_attn:
src = src + self.self_attn.forward2(src, attn_weights)
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward2(src)
src = src + self.feed_forward3(src)
src = self.norm_final(self.balancer(src))