mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Restore feedforward3 module
This commit is contained in:
parent
0379ab57a2
commit
a27670d097
@ -338,8 +338,12 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
dropout)
|
||||
|
||||
self.feed_forward2 = FeedforwardModule(d_model,
|
||||
feedforward_dim,
|
||||
dropout)
|
||||
feedforward_dim,
|
||||
dropout)
|
||||
|
||||
self.feed_forward3 = FeedforwardModule(d_model,
|
||||
feedforward_dim,
|
||||
dropout)
|
||||
|
||||
self.conv_module1 = ConvolutionModule(d_model,
|
||||
cnn_module_kernel)
|
||||
@ -446,13 +450,15 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||
src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
src = src + self.feed_forward2(src)
|
||||
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.self_attn.forward2(src, attn_weights)
|
||||
|
||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||
src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
src = src + self.feed_forward2(src)
|
||||
src = src + self.feed_forward3(src)
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user