Introduce dropout rate to dynamic submodules of conformer.

Daniel Povey 2022-10-31 16:18:52 +08:00
parent 3de8a5aef2
commit 12f17f550e


@@ -383,6 +383,21 @@ class ZipformerEncoderLayer(nn.Module):
                          (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
         return self.bypass_scale.clamp(min=clamp_min, max=1.0)
 
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return (initial_dropout_rate -
+                    (initial_dropout_rate - final_dropout_rate) * (self.batch_count / warmup_period))
+
     def forward(
         self,
         src: Tensor,
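
For reference, the schedule added in the hunk above is a plain linear decay: 0.2 at batch 0, falling to 0.0 at batch 2000 and staying there afterwards. Below is a minimal standalone sketch; the free function and the print loop are illustrative only, while the constant names and values are taken from the diff:

def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial_dropout_rate: float = 0.2,
                         final_dropout_rate: float = 0.0) -> float:
    # Same schedule as ZipformerEncoderLayer.get_dynamic_dropout_rate above,
    # written as a free function so it can be inspected or unit-tested in isolation.
    if batch_count > warmup_period:
        return final_dropout_rate
    # linear interpolation from initial_dropout_rate at batch 0
    # down to final_dropout_rate at batch warmup_period
    return (initial_dropout_rate -
            (initial_dropout_rate - final_dropout_rate) * (batch_count / warmup_period))

for bc in (0, 500, 1000, 1500, 2000, 5000):
    print(bc, dynamic_dropout_rate(bc))  # 0.2, 0.15, 0.1, 0.05, 0.0, 0.0
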
@@ -412,28 +427,37 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # dropout rate for submodules that interact with time.
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
         # pooling module
-        src = src + self.pooling(src,
-                                 key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.pooling(src,
+                                     key_padding_mask=src_key_padding_mask)
 
         # multi-headed self-attention module
-        src_att, attn_weights = self.self_attn(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-        src = src + src_att
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
 
         # convolution module
-        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
 
-        src = src + self.self_attn.forward2(src, attn_weights)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.self_attn.forward2(src, attn_weights)
 
-        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
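
Each guarded call above follows the same layer-drop-style pattern: every dynamic submodule sits on a residual connection (src = src + module(...)), so skipping it with probability dynamic_dropout simply leaves src unchanged. The attention module and its second application self_attn.forward2 share a single use_self_attn draw, so forward2 is never run against attn_weights from a skipped attention call. A toy restatement of the pattern, not code from the repo (the StochasticResidual class and its names are invented for illustration):

import random
import torch
import torch.nn as nn

class StochasticResidual(nn.Module):
    # Residual branch that is skipped with probability dropout_rate during training.
    def __init__(self, branch: nn.Module):
        super().__init__()
        self.branch = branch

    def forward(self, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
        if torch.jit.is_scripting() or random.random() > dropout_rate:
            return x + self.branch(x)
        # Branch dropped: the residual connection degrades to the identity.
        return x

layer = StochasticResidual(nn.Linear(16, 16))
x = torch.randn(4, 16)
y = layer(x, dropout_rate=0.2)  # roughly 20% of calls skip the linear branch
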