Introduce dropout rate to dynamic submodules of conformer.

parent 3de8a5aef2
commit 12f17f550e
@@ -383,6 +383,21 @@ class ZipformerEncoderLayer(nn.Module):
                          (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
         return self.bypass_scale.clamp(min=clamp_min, max=1.0)
 
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return (initial_dropout_rate -
+                    (initial_dropout_rate - final_dropout_rate) * (self.batch_count / warmup_period))
+
     def forward(
         self,
         src: Tensor,
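
A minimal standalone sketch of the schedule this hunk adds (not icefall code; the
function name and printed values are illustrative, assuming the linear ramp from
0.2 down to 0.0 over the first 2000 batches that the comment describes):

def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial_rate: float = 0.2,
                         final_rate: float = 0.0) -> float:
    # after the warmup period the dynamic modules are never dropped
    if batch_count > warmup_period:
        return final_rate
    # linear interpolation between the initial and final rates
    return initial_rate - (initial_rate - final_rate) * (batch_count / warmup_period)

for step in (0, 500, 1000, 2000, 5000):
    print(step, round(dynamic_dropout_rate(step), 3))
# prints 0.2, 0.15, 0.1, 0.0, 0.0 respectively
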
@@ -412,28 +427,37 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # dropout rate for submodules that interact with time.
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
         # pooling module
-        src = src + self.pooling(src,
-                                 key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.pooling(src,
+                                     key_padding_mask=src_key_padding_mask)
 
         # multi-headed self-attention module
-        src_att, attn_weights = self.self_attn(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-        src = src + src_att
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
 
         # convolution module
-        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
 
         src = src + self.feed_forward2(src)
 
-        src = src + self.self_attn.forward2(src, attn_weights)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.self_attn.forward2(src, attn_weights)
 
-        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
 
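
A minimal sketch (not icefall code) of the pattern this hunk applies to the
pooling, self-attention and convolution branches: during training each residual
branch is skipped with probability dynamic_dropout (which is 0.0 outside training
and after the warmup period), and it is always kept when the model is scripted.
The placeholder module and tensor shapes below are assumptions for illustration.

import random
import torch
import torch.nn as nn

class StochasticResidual(nn.Module):
    """Residual branch that is randomly dropped early in training."""

    def __init__(self, d_model: int = 16):
        super().__init__()
        # stands in for the pooling / self_attn / conv_module submodules
        self.submodule = nn.Linear(d_model, d_model)

    def forward(self, src: torch.Tensor, dynamic_dropout: float) -> torch.Tensor:
        # keep the branch when scripting, otherwise with probability (1 - dynamic_dropout)
        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            src = src + self.submodule(src)
        return src

layer = StochasticResidual()
x = torch.randn(4, 16)
# early in training dynamic_dropout is about 0.2, so the branch is skipped
# roughly one batch in five; after the warmup period it is always applied
y = layer(x, dynamic_dropout=0.2)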