mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp999' into scaled_adam_exp1002
This commit is contained in:
commit
a5fb97d298
@ -652,6 +652,26 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
# on which we have randomly chosen to do layer-skipping.
|
||||
return ans
|
||||
|
||||
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
|
||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
|
||||
return None
|
||||
batch_size = x.shape[1]
|
||||
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
|
||||
return mask
|
||||
|
||||
|
||||
def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
|
||||
"""
|
||||
Apply sequence-level dropout to x.
|
||||
x shape: (seq_len, batch_size, embed_dim)
|
||||
"""
|
||||
dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
|
||||
if dropout_mask is None:
|
||||
return x
|
||||
else:
|
||||
return x * dropout_mask
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
@ -699,8 +719,9 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
# attention weights from another one.
|
||||
head_offset = 0 if self.self_attn_weights is not None else 2
|
||||
|
||||
use_self_attn = (random.random() >= attention_skip_rate)
|
||||
if use_self_attn:
|
||||
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
|
||||
|
||||
if True:
|
||||
selected_attn_weights = attn_weights[head_offset:head_offset+2]
|
||||
if random.random() < float(self.const_attention_rate):
|
||||
# Make attention weights constant. The intention is to
|
||||
@ -712,28 +733,23 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
|
||||
selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
|
||||
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.balancer_na(self.nonlin_attention(src,
|
||||
selected_attn_weights[0:1]))
|
||||
|
||||
na = self.balancer_na(self.nonlin_attention(src,
|
||||
selected_attn_weights[0:1]))
|
||||
src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
|
||||
|
||||
src = src + self.feed_forward1(src)
|
||||
|
||||
## pooling module
|
||||
#if torch.jit.is_scripting() or use_self_attn:
|
||||
# src = src + self.balancer_as(
|
||||
# self.attention_squeeze(src, selected_attn_weights[1:2]))
|
||||
self_attn = self.self_attn(
|
||||
src, attn_weights)
|
||||
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
|
||||
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
src = src + self.self_attn(
|
||||
src, attn_weights)
|
||||
src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
|
||||
src_key_padding_mask=src_key_padding_mask),
|
||||
float(self.conv_skip_rate))
|
||||
|
||||
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
|
||||
src = src + self.conv_module(src, chunk_size=chunk_size,
|
||||
src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
|
||||
src = src + self.balancer_ff2(self.feed_forward2(src))
|
||||
src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
|
||||
float(self.ff2_skip_rate))
|
||||
|
||||
src = self.balancer1(src)
|
||||
src = self.norm(src)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user