mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp416' into scaled_adam_exp418
This commit is contained in:
commit
8b50932d5a
@ -164,7 +164,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--pos-dim",
|
||||
type=int,
|
||||
default="128",
|
||||
default="96",
|
||||
help="Positional-encoding embedding dimension"
|
||||
)
|
||||
|
||||
|
||||
@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
feedforward_dim: int,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
# layer_skip_prob will be overwritten to change warmup begin and end times.
|
||||
# layer_skip_rate will be overwritten to change warmup begin and end times.
|
||||
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
|
||||
# to work correctly.
|
||||
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
|
||||
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
|
||||
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
|
||||
dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
|
||||
bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
|
||||
bypass_max: FloatLike = 1.0,
|
||||
) -> None:
|
||||
@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
# probability of skipping the entire layer.
|
||||
self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
|
||||
self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
|
||||
# skip probability for dynamic modules (meaning: anything but feedforward)
|
||||
self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
|
||||
self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
|
||||
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
|
||||
# ever becoming zero.
|
||||
self.bypass_min = copy.deepcopy(bypass_min)
|
||||
@ -459,16 +459,16 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
if self.training and random.random() < float(self.layer_skip_prob):
|
||||
if self.training and random.random() < float(self.layer_skip_rate):
|
||||
# skip the layer
|
||||
return src
|
||||
|
||||
src_orig = src
|
||||
|
||||
# dropout rate for non-feedforward submodules
|
||||
dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
|
||||
dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
|
||||
# multi-headed self-attention module
|
||||
use_self_attn = (random.random() >= dynamic_skip_prob)
|
||||
use_self_attn = (random.random() >= dynamic_skip_rate)
|
||||
|
||||
if torch.jit.is_scripting() or use_self_attn:
|
||||
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
|
||||
@ -493,7 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
src = src + self.self_attn(
|
||||
src, attn_weights)
|
||||
|
||||
if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
|
||||
if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
|
||||
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
src = src + self.feed_forward2(src)
|
||||
@ -534,11 +534,11 @@ class ZipformerEncoder(nn.Module):
|
||||
dropout: float,
|
||||
warmup_begin: float,
|
||||
warmup_end: float,
|
||||
initial_layerdrop_prob: float = 0.5,
|
||||
final_layerdrop_prob: float = 0.05,
|
||||
initial_layerdrop_rate: float = 0.5,
|
||||
final_layerdrop_rate: float = 0.05,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
|
||||
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
@ -552,8 +552,8 @@ class ZipformerEncoder(nn.Module):
|
||||
for i in range(num_layers):
|
||||
cur_end = cur_begin + delta
|
||||
# treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
|
||||
self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
|
||||
(cur_end, final_layerdrop_prob),
|
||||
self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
|
||||
(cur_end, final_layerdrop_rate),
|
||||
default=0.0)
|
||||
cur_begin = cur_end
|
||||
|
||||
@ -1002,7 +1002,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
query_head_dim: dimension of the query (and key), per head. e.g. 24.
|
||||
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
|
||||
dropout: dropout probability for attn_output_weights. Default: 0.0.
|
||||
pos_emb_skip: probability for skipping the pos_emb part of the scores on
|
||||
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
|
||||
any given call to forward(), in training time.
|
||||
"""
|
||||
|
||||
@ -1014,8 +1014,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
query_head_dim: int,
|
||||
pos_head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
|
||||
(4000.0, 0.075))
|
||||
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
|
||||
(4000.0, 0.0))
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -1023,7 +1023,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
self.query_head_dim = query_head_dim
|
||||
self.pos_head_dim = pos_head_dim
|
||||
self.dropout = dropout
|
||||
self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
|
||||
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
|
||||
|
||||
key_head_dim = query_head_dim
|
||||
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
|
||||
@ -1108,7 +1108,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
|
||||
attn_scores = torch.matmul(q, k)
|
||||
|
||||
if not self.training or random.random() >= float(self.pos_emb_skip):
|
||||
if not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||
pos_emb = self.linear_pos(pos_emb)
|
||||
seq_len2 = 2 * seq_len - 1
|
||||
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user