Make pos_emb dropout rate be constant during training; also cosmetic changes
commit f76075fd1a
parent 867556200f
@@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        # layer_skip_prob will be overwritten to change warmup begin and end times.
+        # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
+        layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
+        dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
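For context, ScheduledFloat (defined in icefall's scaling.py) is a value that changes with the training batch count: the (batch, value) pairs above are breakpoints, and default is what float(...) returns when no batch count is available (e.g. when scripting). Below is a minimal, self-contained sketch of that idea; ToyScheduledFloat is a hypothetical stand-in, and linear interpolation between breakpoints is an assumption, not a copy of the real class.

# Toy stand-in for icefall's ScheduledFloat: piecewise-linear in the batch count,
# clamped at the first/last breakpoints.  Illustration only, not the real class.
from bisect import bisect_right


class ToyScheduledFloat:
    def __init__(self, *points, default=0.0):
        # points: (batch_index, value) pairs, e.g. (0.0, 0.5), (2000.0, 0.05)
        self.points = sorted(points)
        self.default = default
        self.batch_count = None  # set by the training loop; None outside training

    def __float__(self):
        if self.batch_count is None:
            return float(self.default)
        xs = [x for x, _ in self.points]
        i = bisect_right(xs, self.batch_count)
        if i == 0:
            return float(self.points[0][1])
        if i == len(self.points):
            return float(self.points[-1][1])
        (x0, y0), (x1, y1) = self.points[i - 1], self.points[i]
        t = (self.batch_count - x0) / (x1 - x0)
        return float(y0 + t * (y1 - y0))


layer_skip_rate = ToyScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0)
for batch in (0, 500, 1000, 2000, 10000):
    layer_skip_rate.batch_count = batch
    print(batch, float(layer_skip_rate))  # decays from 0.5 to 0.05, then stays at 0.05

Under those assumptions, the default layer_skip_rate above starts near 0.5 and settles at 0.05 after roughly 2000 batches; float(...) is how the forward code reads it in the hunks below.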
@@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.embed_dim = embed_dim

         # probability of skipping the entire layer.
-        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward)
-        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
@@ -466,7 +466,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
-        if self.training and random.random() < float(self.layer_skip_prob):
+        if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src

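The check above is a layerdrop-style trick: during training, with probability float(self.layer_skip_rate) the whole layer is bypassed and the input tensor is returned untouched. A minimal sketch of the pattern, using a hypothetical module and a plain float rate instead of a schedule:

import random

import torch
import torch.nn as nn


class SkippableLayer(nn.Module):
    """Hypothetical layer that is stochastically dropped during training."""

    def __init__(self, dim: int, layer_skip_rate: float = 0.1):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.layer_skip_rate = layer_skip_rate

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # In training, skip the whole layer with probability layer_skip_rate;
        # the residual structure makes the skip a clean no-op on the output.
        if self.training and random.random() < float(self.layer_skip_rate):
            return src
        return src + self.body(src)


layer = SkippableLayer(dim=8, layer_skip_rate=0.5)
x = torch.randn(3, 8)
layer.train()
_ = layer(x)   # may or may not apply self.body
layer.eval()
_ = layer(x)   # always applies self.body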
@@ -476,9 +476,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)

         # dropout rate for non-feedforward submodules
-        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
+        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_prob)
+        use_self_attn = (random.random() >= dynamic_skip_rate)

         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
@@ -508,7 +508,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn2(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward3(src)
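dynamic_skip_rate does the same thing one level down: each non-feedforward submodule (self-attention, convolution) gets its own coin flip, and the torch.jit.is_scripting() guard ensures scripted/exported models never skip anything. A rough sketch of that control flow, with nn.Linear placeholders standing in for the real submodules:

import random

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a layer with per-submodule skipping."""

    def __init__(self, dim: int, dynamic_skip_rate: float = 0.2):
        super().__init__()
        self.feed_forward = nn.Linear(dim, dim)
        self.attn_like = nn.Linear(dim, dim)   # placeholder for self-attention
        self.conv_like = nn.Linear(dim, dim)   # placeholder for the conv module
        self.dynamic_skip_rate = dynamic_skip_rate

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        src = src + self.feed_forward(src)  # feedforward is never skipped

        # Skip rate is zero outside training, and skipping is disabled entirely
        # when the module is being scripted/exported.
        rate = float(self.dynamic_skip_rate) if self.training else 0.0
        if torch.jit.is_scripting() or random.random() >= rate:
            src = src + self.attn_like(src)
        if torch.jit.is_scripting() or random.random() >= rate:
            src = src + self.conv_like(src)
        return src


block = ToyBlock(dim=8)
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])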
@@ -549,8 +549,8 @@ class ZipformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
-        initial_layerdrop_prob: float = 0.5,
-        final_layerdrop_prob: float = 0.05,
+        initial_layerdrop_rate: float = 0.5,
+        final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
@@ -567,8 +567,8 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob),
+            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                            (cur_end, final_layerdrop_rate),
                                                             default=0.0)
             cur_begin = cur_end

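The encoder staggers these schedules across layers: assuming delta = (warmup_end - warmup_begin) / num_layers (the definition of delta is outside this hunk), each layer's skip rate decays from initial_layerdrop_rate to final_layerdrop_rate over its own slice of the warmup period, with later layers finishing later. A quick sketch of how the (cur_begin, cur_end) windows come out, with illustrative numbers:

# Sketch of the staggered layerdrop warmup windows; numbers are illustrative,
# and delta = (warmup_end - warmup_begin) / num_layers is an assumption.
warmup_begin, warmup_end = 2000.0, 4000.0
num_layers = 4
initial_layerdrop_rate, final_layerdrop_rate = 0.5, 0.05

delta = (warmup_end - warmup_begin) / num_layers
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    # Layer i's layer_skip_rate is scheduled from (cur_begin, 0.5) down to (cur_end, 0.05).
    print(f"layer {i}: skip rate {initial_layerdrop_rate} at batch {cur_begin:.0f} "
          f"-> {final_layerdrop_rate} at batch {cur_end:.0f}")
    cur_begin = cur_end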
@@ -1018,7 +1018,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: dimension of the query (and key), per head. e.g. 24.
         pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
         dropout: dropout probability for attn_output_weights. Default: 0.0.
-        pos_emb_skip: probability for skipping the pos_emb part of the scores on
+        pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
             any given call to forward(), in training time.
     """

@@ -1030,8 +1030,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: int,
         pos_head_dim: int,
         dropout: float = 0.0,
-        pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                 (4000.0, 0.05))
+        pos_emb_skip_rate: FloatLike = 0.05,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
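This hunk is the change the commit title refers to: the default probability of dropping the positional part of the attention scores was a ScheduledFloat decaying from 0.5 to 0.05 over the first 4000 batches, and is now the constant 0.05 for the whole of training. A small before/after sketch of the rate as a function of batch index (the linear decay of the old schedule is an assumption about ScheduledFloat's interpolation):

def old_pos_emb_skip_rate(batch: float) -> float:
    # Old default: ScheduledFloat((0.0, 0.5), (4000.0, 0.05)) -- assumed linear decay.
    if batch >= 4000.0:
        return 0.05
    return 0.5 + (batch / 4000.0) * (0.05 - 0.5)


NEW_POS_EMB_SKIP_RATE = 0.05  # new constant default

for batch in (0, 1000, 4000, 20000):
    print(f"batch {batch:>6}: old rate {old_pos_emb_skip_rate(batch):.3f}, "
          f"new rate {NEW_POS_EMB_SKIP_RATE:.3f}")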
@@ -1039,7 +1038,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)

         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1124,7 +1123,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         attn_scores = torch.matmul(q, k)

-        if not self.training or random.random() >= float(self.pos_emb_skip):
+        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
             pos_emb = self.linear_pos(pos_emb)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
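The last hunk is where the rate is consumed: during training, with probability float(self.pos_emb_skip_rate) the relative-positional term is simply not added to attn_scores, so that call uses content-only attention. A stripped-down, single-head sketch of the pattern; toy_attn_scores and its precomputed pos_scores are illustrative stand-ins, not the real RelPositionMultiheadAttentionWeights math:

import random

import torch


def toy_attn_scores(q: torch.Tensor,
                    k: torch.Tensor,
                    pos_scores: torch.Tensor,
                    pos_emb_skip_rate: float = 0.05,
                    training: bool = True) -> torch.Tensor:
    """Content scores plus (usually) a positional term, single-head toy version.

    q, k: (seq_len, head_dim); pos_scores: (seq_len, seq_len), a precomputed
    stand-in for the relative-positional contribution.
    """
    attn_scores = torch.matmul(q, k.t())          # content-content term
    if not training or random.random() >= pos_emb_skip_rate:
        attn_scores = attn_scores + pos_scores    # positional term, usually kept
    return attn_scores.softmax(dim=-1)


seq_len, head_dim = 5, 8
q, k = torch.randn(seq_len, head_dim), torch.randn(seq_len, head_dim)
pos_scores = torch.randn(seq_len, seq_len)
print(toy_attn_scores(q, k, pos_scores).shape)    # torch.Size([5, 5])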