mirror of https://github.com/k2-fsa/icefall.git

commit 8b50932d5a
Merge branch 'scaled_adam_exp416' into scaled_adam_exp418
@@ -164,7 +164,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--pos-dim",
         type=int,
-        default="128",
+        default="96",
         help="Positional-encoding embedding dimension"
     )
 
@@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        # layer_skip_prob will be overwritten to change warmup begin and end times.
+        # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+        layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
+        dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
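Note on the ScheduledFloat defaults above: layer_skip_rate and dynamic_skip_rate are schedules over the global batch count, decaying from an aggressive initial value to a small final one by batch 4000. The sketch below is a simplified stand-in for that behaviour (a PiecewiseLinearSchedule class of my own, not icefall's ScheduledFloat from scaling.py), just to make the interpolation implied by the (batch, value) pairs concrete.

# Simplified stand-in for a schedule like ScheduledFloat((0.0, 0.5), (4000.0, 0.05)).
# Illustrative only; not the icefall implementation.
class PiecewiseLinearSchedule:
    def __init__(self, *points):
        # points: (x, y) pairs, e.g. (batch_count, skip_rate), with increasing x.
        self.points = sorted(points)

    def value(self, x: float) -> float:
        xs = [p[0] for p in self.points]
        ys = [p[1] for p in self.points]
        if x <= xs[0]:
            return ys[0]
        if x >= xs[-1]:
            return ys[-1]
        # linear interpolation between the two surrounding points
        for (x0, y0), (x1, y1) in zip(self.points, self.points[1:]):
            if x0 <= x <= x1:
                t = (x - x0) / (x1 - x0)
                return y0 + t * (y1 - y0)


layer_skip_rate = PiecewiseLinearSchedule((0.0, 0.5), (4000.0, 0.05))
print(layer_skip_rate.value(0.0))      # 0.5   -- layers dropped often at the start
print(layer_skip_rate.value(2000.0))   # ~0.275 -- halfway through the ramp
print(layer_skip_rate.value(10000.0))  # 0.05  -- settles at the final rate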
@@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.embed_dim = embed_dim
 
         # probability of skipping the entire layer.
-        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward)
-        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
@@ -459,16 +459,16 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
-        if self.training and random.random() < float(self.layer_skip_prob):
+        if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src
 
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
+        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_prob)
+        use_self_attn = (random.random() >= dynamic_skip_rate)
 
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
@@ -493,7 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)
 
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
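The two forward-pass hunks above follow the same stochastic-bypass idea: during training a sub-module (self-attention, convolution) is skipped with probability dynamic_skip_rate and the residual connection simply passes src through unchanged; at inference, or when scripting with TorchScript, the sub-module always runs. A minimal, self-contained sketch of that pattern (a hypothetical StochasticBypass wrapper, not the actual ZipformerEncoderLayer):

import random
import torch
import torch.nn as nn

class StochasticBypass(nn.Module):
    """Residual wrapper that randomly skips its sub-module during training.

    Hypothetical sketch of the dynamic_skip_rate pattern; the real layer applies
    this per sub-module (self_attn, conv_module, ...) and draws the rate from a
    schedule rather than a fixed float.
    """
    def __init__(self, module: nn.Module, skip_rate: float = 0.2):
        super().__init__()
        self.module = module
        self.skip_rate = skip_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_rate = self.skip_rate if self.training else 0.0
        if torch.jit.is_scripting() or random.random() >= skip_rate:
            return x + self.module(x)   # run the sub-module, residual add
        return x                        # skip: identity pass-through

layer = StochasticBypass(nn.Linear(16, 16), skip_rate=0.2)
x = torch.randn(4, 16)
layer.train(); y_train = layer(x)   # sub-module skipped ~20% of the time
layer.eval();  y_eval = layer(x)    # sub-module always runs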
@@ -534,11 +534,11 @@ class ZipformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
-        initial_layerdrop_prob: float = 0.5,
-        final_layerdrop_prob: float = 0.05,
+        initial_layerdrop_rate: float = 0.5,
+        final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
 
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -552,8 +552,8 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob),
+            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                            (cur_end, final_layerdrop_rate),
                                                             default=0.0)
             cur_begin = cur_end
 
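For reference, the loop above staggers the layer-drop warmup across layers: the interval [warmup_begin, warmup_end] is split into per-layer slices, and layer i ramps its layer_skip_rate from initial_layerdrop_rate down to final_layerdrop_rate over its own slice. The helper below is hypothetical (including the assumption that delta is the slice width (warmup_end - warmup_begin) / num_layers) and returns plain (batch, value) pairs instead of ScheduledFloat objects, purely to show the layout:

def layerdrop_schedules(num_layers: int,
                        warmup_begin: float,
                        warmup_end: float,
                        initial_layerdrop_rate: float = 0.5,
                        final_layerdrop_rate: float = 0.05):
    # Assumption: each layer gets an equal slice of the warmup interval.
    delta = (warmup_end - warmup_begin) / num_layers
    cur_begin = warmup_begin
    schedules = []
    for i in range(num_layers):
        cur_end = cur_begin + delta
        # layer i: skip rate decays from initial to final over [cur_begin, cur_end]
        schedules.append(((cur_begin, initial_layerdrop_rate),
                          (cur_end, final_layerdrop_rate)))
        cur_begin = cur_end
    return schedules

print(layerdrop_schedules(4, warmup_begin=0.0, warmup_end=2000.0))
# [((0.0, 0.5), (500.0, 0.05)), ((500.0, 0.5), (1000.0, 0.05)), ...]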
@@ -1002,7 +1002,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        query_head_dim: dimension of the query (and key), per head. e.g. 24.
        pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
        dropout: dropout probability for attn_output_weights. Default: 0.0.
-       pos_emb_skip: probability for skipping the pos_emb part of the scores on
+       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
            any given call to forward(), in training time.
     """
 
@@ -1014,8 +1014,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: int,
         pos_head_dim: int,
         dropout: float = 0.0,
-        pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                 (4000.0, 0.075))
+        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
+                                                      (4000.0, 0.0))
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1023,7 +1023,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1108,7 +1108,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k)
 
-        if not self.training or random.random() >= float(self.pos_emb_skip):
+        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
             pos_emb = self.linear_pos(pos_emb)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
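This last hunk gates the relative-positional contribution to the attention scores: during training, with probability pos_emb_skip_rate, only the content (query-key) term is used. A condensed sketch of that gating with simplified shapes (attention_scores is a hypothetical helper; the real module derives the positional term from linear_pos(pos_emb) plus a relative-shift, which is omitted here):

import random
import torch

def attention_scores(q, k, pos_scores, pos_emb_skip_rate: float, training: bool):
    """Content scores plus an optionally skipped positional term.

    q:          (num_heads, batch, seq_len, head_dim)
    k:          (num_heads, batch, head_dim, seq_len)
    pos_scores: (num_heads, batch, seq_len, seq_len), precomputed positional term
    """
    attn_scores = torch.matmul(q, k)                  # content (query-key) term
    if not training or random.random() >= pos_emb_skip_rate:
        attn_scores = attn_scores + pos_scores        # positional term kept
    return attn_scores                                # else: content term only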