Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Revert "Make sub-module dropped out independently."
This reverts commit 3ff3f440ee6d2a367cc3cc45e40f8eb69d122861.
parent 742bcaa340
commit f7aff4f507
@@ -487,6 +487,11 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
+        # multi-headed self-attention module
+        # TODO: make the various attention-using models be dropped
+        # out independently.
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
                 src,
@@ -495,22 +500,22 @@ class ZipformerEncoderLayer(nn.Module):
                 key_padding_mask=src_key_padding_mask,
             )
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn1(
                 src, attn_weights)
 
         # convolution module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
 
         src = src + self.feed_forward2(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn2(
                 src, attn_weights)
 
@@ -520,9 +525,10 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze2(src, attn_weights[2:3])
 
+
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
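
For context, here is a minimal, self-contained sketch of the pattern this revert restores: one random draw per forward pass gates all of the attention-dependent submodules together, instead of each call site drawing its own random number as in the reverted commit. ToyLayer, dynamic_dropout_rate, and the use of nn.MultiheadAttention are illustrative assumptions for this sketch, not identifiers from the icefall code.

import random

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Toy stand-in for an encoder layer with structured submodule dropout."""

    def __init__(self, dim: int, dynamic_dropout_rate: float = 0.1):
        super().__init__()
        self.dynamic_dropout_rate = dynamic_dropout_rate
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.feed_forward = nn.Linear(dim, dim)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # One shared decision per forward pass: with probability
        # dynamic_dropout_rate, skip every submodule that uses attention.
        # (Gating on self.training is a simplification of this sketch; the
        # diff above gets its rate from self.get_dynamic_dropout_rate().)
        use_self_attn = (
            not self.training or random.random() > self.dynamic_dropout_rate
        )

        if torch.jit.is_scripting() or use_self_attn:
            attn_out, _ = self.attn(src, src, src, need_weights=False)
            src = src + attn_out

        # Submodules that do not interact with attention always run.
        src = src + self.feed_forward(src)

        # Gated by the same flag, so the attention-using submodules are
        # dropped (or kept) together within one forward pass.
        if torch.jit.is_scripting() or use_self_attn:
            attn_out, _ = self.attn(src, src, src, need_weights=False)
            src = src + attn_out

        return src


if __name__ == "__main__":
    layer = ToyLayer(dim=64)
    x = torch.randn(2, 10, 64)
    print(layer(x).shape)  # torch.Size([2, 10, 64])

By contrast, the change being reverted replaced the shared use_self_attn flag with an independent random.random() > dynamic_dropout test at each call site, so within a single forward pass one attention-using submodule could be skipped while another still ran.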