Merge branch 'scaled_adam_exp999' into scaled_adam_exp1002

This commit is contained in:
Daniel Povey 2023-02-13 12:49:49 +08:00
commit a5fb97d298

View File

@ -652,6 +652,26 @@ class ZipformerEncoderLayer(nn.Module):
# on which we have randomly chosen to do layer-skipping. # on which we have randomly chosen to do layer-skipping.
return ans return ans
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
return mask
def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
"""
Apply sequence-level dropout to x.
x shape: (seq_len, batch_size, embed_dim)
"""
dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
if dropout_mask is None:
return x
else:
return x * dropout_mask
def forward( def forward(
self, self,
src: Tensor, src: Tensor,
@ -699,8 +719,9 @@ class ZipformerEncoderLayer(nn.Module):
# attention weights from another one. # attention weights from another one.
head_offset = 0 if self.self_attn_weights is not None else 2 head_offset = 0 if self.self_attn_weights is not None else 2
use_self_attn = (random.random() >= attention_skip_rate) self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
if use_self_attn:
if True:
selected_attn_weights = attn_weights[head_offset:head_offset+2] selected_attn_weights = attn_weights[head_offset:head_offset+2]
if random.random() < float(self.const_attention_rate): if random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to # Make attention weights constant. The intention is to
@ -712,28 +733,23 @@ class ZipformerEncoderLayer(nn.Module):
selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1) selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
if torch.jit.is_scripting() or use_self_attn:
src = src + self.balancer_na(self.nonlin_attention(src,
selected_attn_weights[0:1]))
na = self.balancer_na(self.nonlin_attention(src,
selected_attn_weights[0:1]))
src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
src = src + self.feed_forward1(src) src = src + self.feed_forward1(src)
## pooling module self_attn = self.self_attn(
#if torch.jit.is_scripting() or use_self_attn: src, attn_weights)
# src = src + self.balancer_as( src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
# self.attention_squeeze(src, selected_attn_weights[1:2]))
if torch.jit.is_scripting() or use_self_attn: src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
src = src + self.self_attn( src_key_padding_mask=src_key_padding_mask),
src, attn_weights) float(self.conv_skip_rate))
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate): src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
src = src + self.conv_module(src, chunk_size=chunk_size, float(self.ff2_skip_rate))
src_key_padding_mask=src_key_padding_mask)
if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
src = src + self.balancer_ff2(self.feed_forward2(src))
src = self.balancer1(src) src = self.balancer1(src)
src = self.norm(src) src = self.norm(src)