mirror of https://github.com/k2-fsa/icefall.git
commit a5fb97d298
Merge branch 'scaled_adam_exp999' into scaled_adam_exp1002
@@ -652,6 +652,26 @@ class ZipformerEncoderLayer(nn.Module):
                 # on which we have randomly chosen to do layer-skipping.
             return ans

+    def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+            return None
+        batch_size = x.shape[1]
+        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+        return mask
+
+
+    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+        """
+        Apply sequence-level dropout to x.
+        x shape: (seq_len, batch_size, embed_dim)
+        """
+        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+        if dropout_mask is None:
+            return x
+        else:
+            return x * dropout_mask
+
+
     def forward(
         self,
         src: Tensor,
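The two methods added in this hunk implement sequence-level dropout: rather than zeroing individual elements, each sequence in the batch is kept or zeroed as a whole, using a (batch_size, 1) mask that broadcasts over the (seq_len, batch_size, embed_dim) input. A standalone sketch of the same idea (the function name, shapes, and example call below are illustrative, not part of the commit):

import torch

def sequence_dropout_sketch(x: torch.Tensor, dropout_rate: float, training: bool = True) -> torch.Tensor:
    # x: (seq_len, batch_size, embed_dim)
    if dropout_rate == 0.0 or not training:
        return x
    # One keep (1.0) / drop (0.0) decision per sequence in the batch.
    mask = (torch.rand(x.shape[1], 1, device=x.device) > dropout_rate).to(x.dtype)
    # The (batch_size, 1) mask broadcasts over seq_len and embed_dim.
    return x * mask

x = torch.randn(10, 4, 8)            # (seq_len=10, batch_size=4, embed_dim=8)
y = sequence_dropout_sketch(x, 0.5)
# For each batch index b, y[:, b, :] is either identical to x[:, b, :] or all zeros.

Splitting the mask computation (get_sequence_dropout_mask) from its application (sequence_dropout) lets one mask be reused for several branches, as the later hunks do with self_attn_dropout_mask.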
@@ -699,8 +719,9 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2

-        use_self_attn = (random.random() >= attention_skip_rate)
-        if use_self_attn:
+        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+
+        if True:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
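This hunk replaces the single coin flip use_self_attn, which skipped self-attention for the whole minibatch, with self_attn_dropout_mask drawn by get_sequence_dropout_mask: the attention code now always runs (hence if True:) and its output is masked per sequence further down. A short sketch of the difference in randomness granularity (the skip rate value and tensor shapes here are assumed for illustration):

import random
import torch

attention_skip_rate = 0.2
src = torch.randn(10, 4, 256)        # (seq_len, batch_size, embed_dim)

# Before: one coin flip decided whether the entire batch used self-attention.
use_self_attn = random.random() >= attention_skip_rate

# After: an independent keep (1.0) / drop (0.0) decision for each of the 4 sequences,
# analogous to get_sequence_dropout_mask() in the hunk above.
self_attn_dropout_mask = (torch.rand(src.shape[1], 1) > attention_skip_rate).to(src.dtype)

With the per-sequence mask, roughly a (1 - attention_skip_rate) fraction of the sequences keeps the attention contribution in every batch, instead of the whole batch being kept or skipped together.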
@@ -712,28 +733,23 @@ class ZipformerEncoderLayer(nn.Module):
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
             selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.balancer_na(self.nonlin_attention(src,
-                                                               selected_attn_weights[0:1]))
-
+        na = self.balancer_na(self.nonlin_attention(src,
+                                                    selected_attn_weights[0:1]))
+        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)

         src = src + self.feed_forward1(src)

-        ## pooling module
-        #if torch.jit.is_scripting() or use_self_attn:
-        # src = src + self.balancer_as(
-        #     self.attention_squeeze(src, selected_attn_weights[1:2]))
+        self_attn = self.self_attn(
+            src, attn_weights)
+        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn(
-                src, attn_weights)
+        src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
+                                                           src_key_padding_mask=src_key_padding_mask),
+                                          float(self.conv_skip_rate))

-        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
-            src = src + self.conv_module(src, chunk_size=chunk_size,
-                                         src_key_padding_mask=src_key_padding_mask)
+        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
+                                          float(self.ff2_skip_rate))

-        if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
-            src = src + self.balancer_ff2(self.feed_forward2(src))

         src = self.balancer1(src)
         src = self.norm(src)
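In this hunk every residual branch (nonlin_attention, self_attn, conv_module, feed_forward2) moves from an if random.random() >= skip_rate guard to the sequence-level scheme: the branch always executes and its output is multiplied by the per-sequence mask, either explicitly (self_attn_dropout_mask) or via sequence_dropout. A minimal sketch of that residual pattern, with assumed helper and module names rather than the real Zipformer submodules:

import torch
import torch.nn as nn

def residual_with_sequence_dropout(src: torch.Tensor,
                                   branch_out: torch.Tensor,
                                   skip_rate: float,
                                   training: bool = True) -> torch.Tensor:
    # src, branch_out: (seq_len, batch_size, embed_dim)
    if skip_rate == 0.0 or not training:
        return src + branch_out
    # (batch_size, 1) keep-mask, broadcast over seq_len and embed_dim.
    keep = (torch.rand(src.shape[1], 1, device=src.device) > skip_rate).to(src.dtype)
    # Dropped sequences retain only the identity path of the residual connection.
    return src + branch_out * keep

src = torch.randn(10, 4, 64)                       # (seq_len, batch_size, embed_dim)
ff = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
src = residual_with_sequence_dropout(src, ff(src), skip_rate=0.3)

For a dropped sequence the layer degrades to a bypass, which is the same effect the old skipping logic had, except that the decision is now made independently for each sequence in the batch rather than once per forward call.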