Make most forms of sequence dropout be separate per sequence.

Daniel Povey 2023-02-10 16:34:01 +08:00
parent e7e7560bba
commit ad388890d9
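The change below replaces several whole-batch coin flips (e.g. `use_self_attn = (random.random() >= attention_skip_rate)`) with a per-sequence dropout mask of shape (batch_size, 1), so each utterance in the batch independently keeps or drops a sub-module's output. A minimal standalone sketch of that idea, mirroring the new helpers in the diff (the function name, the toy tensor and the explicit `training` flag are illustrative, not part of the commit):

import torch
from typing import Optional

def sequence_dropout_mask(x: torch.Tensor, dropout_rate: float,
                          training: bool = True) -> Optional[torch.Tensor]:
    # x: (seq_len, batch_size, embed_dim).  Returns a (batch_size, 1) tensor of
    # zeros and ones, one Bernoulli draw per sequence, or None if dropout is off.
    if dropout_rate == 0.0 or not training:
        return None
    batch_size = x.shape[1]
    return (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)

x = torch.randn(5, 4, 8)              # (seq_len=5, batch_size=4, embed_dim=8)
mask = sequence_dropout_mask(x, 0.5)  # e.g. tensor([[1.], [0.], [1.], [1.]])
y = x if mask is None else x * mask   # broadcasting zeroes entire sequences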


@@ -623,6 +623,26 @@ class ZipformerEncoderLayer(nn.Module):
         # on which we have randomly chosen to do layer-skipping.
         return ans
 
+    def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+            return None
+        batch_size = x.shape[1]
+        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+        return mask
+
+    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+        """
+        Apply sequence-level dropout to x.
+        x shape: (seq_len, batch_size, embed_dim)
+        """
+        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+        if dropout_mask is None:
+            return x
+        else:
+            return x * dropout_mask
+
     def forward(
         self,
         src: Tensor,
@@ -670,8 +690,9 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
 
-        use_self_attn = (random.random() >= attention_skip_rate)
-        if use_self_attn:
+        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+
+        if True:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
@@ -683,28 +704,23 @@ class ZipformerEncoderLayer(nn.Module):
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                 selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.balancer_na(self.nonlin_attention(src,
-                                                               selected_attn_weights[0:1]))
+        na = self.balancer_na(self.nonlin_attention(src,
+                                                    selected_attn_weights[0:1]))
+
+        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
 
         src = src + self.feed_forward1(src)
 
-        # pooling module
-        #if torch.jit.is_scripting() or use_self_attn:
-        #    src = src + self.balancer_as(
-        #        self.attention_squeeze(src, selected_attn_weights[1:2]))
-
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn(
-                src, attn_weights)
+        self_attn = self.self_attn(
+            src, attn_weights)
+
+        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
-            src = src + self.conv_module(src, chunk_size=chunk_size,
-                                         src_key_padding_mask=src_key_padding_mask)
+        src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
+                                                           src_key_padding_mask=src_key_padding_mask),
+                                          float(self.conv_skip_rate))
 
-        if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
-            src = src + self.balancer_ff2(self.feed_forward2(src))
+        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
+                                          float(self.ff2_skip_rate))
 
         src = self.balancer1(src)
         src = self.norm(src)
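Note that in the forward pass above a single mask (self_attn_dropout_mask, drawn from attention_skip_rate) gates both the nonlin_attention branch and the self_attn branch, so a given sequence either uses its attention weights in both places or in neither, while the conv_module and feed_forward2 branches each draw their own mask through sequence_dropout. A rough sketch of the behaviour this gives, under assumed toy sizes (names and sizes here are illustrative only):

import torch

seq_len, batch_size, embed_dim = 10, 4, 8
x = torch.randn(seq_len, batch_size, embed_dim)

# Training: one draw per sequence, broadcast over time and channels.
mask = (torch.rand(batch_size, 1) > 0.5).to(x.dtype)   # shape (batch_size, 1)
dropped = x * mask
assert dropped.shape == x.shape
zeroed = (mask.squeeze(1) == 0)
# Sequences whose draw came up 0 are zeroed at every frame and channel:
assert torch.all(dropped[:, zeroed, :] == 0)

# At eval time or under torch.jit.is_scripting(), get_sequence_dropout_mask
# returns None, so every branch reduces to a plain residual add.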