Merge branch 'scaled_adam_exp416' into scaled_adam_exp418

commit 8b50932d5a
Author: Daniel Povey
Date:   2022-11-17 18:34:07 +08:00
2 changed files with 20 additions and 20 deletions

View File

@@ -164,7 +164,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--pos-dim",
         type=int,
-        default="128",
+        default="96",
         help="Positional-encoding embedding dimension"
     )
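
Note: the default is given as the string "96" while type=int is set; argparse applies the type conversion to string defaults, so downstream code still receives an integer. A minimal standalone check of that behaviour (a hypothetical reproduction, not part of the recipe):

import argparse

# Hypothetical standalone reproduction of the --pos-dim argument above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pos-dim",
    type=int,
    default="96",
    help="Positional-encoding embedding dimension",
)

args = parser.parse_args([])  # no command-line override given
assert args.pos_dim == 96 and isinstance(args.pos_dim, int)  # string default is passed through type=int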

View File

@@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        # layer_skip_prob will be overwritten to change warmup begin and end times.
+        # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+        layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
+        dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
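
Note: the ScheduledFloat defaults above are schedules over the training batch count: layer_skip_rate decays from 0.5 at batch 0 to 0.05 at batch 4000 and then stays there, and dynamic_skip_rate decays from 0.2 to 0.0 over the same span. A rough sketch of that piecewise-linear behaviour (assuming simple linear interpolation between the (batch_count, value) points, not icefall's actual ScheduledFloat implementation):

from typing import Sequence, Tuple

def scheduled_value(points: Sequence[Tuple[float, float]], batch_count: float) -> float:
    # Piecewise-linear interpolation between (batch_count, value) pairs,
    # clamped to the first/last value outside the given range.
    (x0, y0) = points[0]
    if batch_count <= x0:
        return y0
    for (x1, y1) in points[1:]:
        if batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
        x0, y0 = x1, y1
    return y0

# layer_skip_rate schedule from the signature above:
print(scheduled_value([(0.0, 0.5), (4000.0, 0.05)], 0))      # 0.5
print(scheduled_value([(0.0, 0.5), (4000.0, 0.05)], 2000))   # 0.275
print(scheduled_value([(0.0, 0.5), (4000.0, 0.05)], 10000))  # 0.05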
@@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.embed_dim = embed_dim
         # probability of skipping the entire layer.
-        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward)
-        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
@@ -459,16 +459,16 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
-        if self.training and random.random() < float(self.layer_skip_prob):
+        if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src

         src_orig = src

         # dropout rate for non-feedforward submodules
-        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
+        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0

         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_prob)
+        use_self_attn = (random.random() >= dynamic_skip_rate)

         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
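
Note: the renamed fields drive two independent kinds of stochastic skipping in forward(): with probability layer_skip_rate the whole residual layer is bypassed, and with probability dynamic_skip_rate each "dynamic" submodule (self-attention, convolution) is dropped on its own; both apply only in training mode. A simplified, self-contained sketch of the pattern (a toy module, not ZipformerEncoderLayer):

import random

import torch
import torch.nn as nn


class SkipSketch(nn.Module):
    """Toy illustration of the layer-skip / dynamic-skip pattern above."""

    def __init__(self, layer_skip_rate: float = 0.05, dynamic_skip_rate: float = 0.0):
        super().__init__()
        self.layer_skip_rate = layer_skip_rate
        self.dynamic_skip_rate = dynamic_skip_rate
        self.body = nn.Linear(8, 8)  # stand-in for self_attn / conv_module

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Whole-layer skip, training only: return the input unchanged.
        if self.training and random.random() < float(self.layer_skip_rate):
            return src
        # Submodule skip rate is forced to 0.0 outside training.
        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
        if random.random() >= dynamic_skip_rate:
            src = src + self.body(src)
        return src


layer = SkipSketch(layer_skip_rate=0.5, dynamic_skip_rate=0.2)
layer.eval()  # in eval mode nothing is ever skipped
out = layer(torch.randn(4, 8))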
@@ -493,7 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward2(src)
@@ -534,11 +534,11 @@ class ZipformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
-        initial_layerdrop_prob: float = 0.5,
-        final_layerdrop_prob: float = 0.05,
+        initial_layerdrop_rate: float = 0.5,
+        final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
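
Note: raising the dropout_rate of CompactRelPositionalEncoding from 0.0 to 0.15 means the positional embeddings are now regularised during training. Assuming the class applies ordinary dropout to the embedding tensor it returns (an assumption; the class is defined elsewhere in this file), the effect is roughly:

import torch
import torch.nn as nn

# Toy illustration only: assume the positional-encoding module applies standard
# dropout to the (batch, 2*seq_len - 1, pos_dim) embedding tensor it produces.
pos_dropout = nn.Dropout(p=0.15)        # modules start in training mode
pos_emb = torch.randn(1, 2 * 100 - 1, 96)  # assumed shape: seq_len=100, pos_dim=96 (new --pos-dim default)
out = pos_dropout(pos_emb)
print(f"zeroed fraction: {(out == 0).float().mean():.2f}")  # ~0.15; survivors scaled by 1/0.85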
@@ -552,8 +552,8 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob),
+            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                            (cur_end, final_layerdrop_rate),
                                                             default=0.0)
             cur_begin = cur_end
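
Note: the loop above staggers the layerdrop warmup, so each layer gets its own [cur_begin, cur_end] window over which layer_skip_rate decays from initial_layerdrop_rate to final_layerdrop_rate. A small sketch of the resulting windows (delta and the concrete numbers below are assumptions about code outside this hunk; here delta is taken to be the per-layer share of the overall warmup span):

# Hypothetical reconstruction of the scheduling loop, for illustration only.
num_layers = 6
warmup_begin, warmup_end = 1000.0, 3000.0
initial_layerdrop_rate, final_layerdrop_rate = 0.5, 0.05

delta = (warmup_end - warmup_begin) / num_layers

cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    # layer i's layer_skip_rate decays linearly over [cur_begin, cur_end].
    print(f"layer {i}: {initial_layerdrop_rate} -> {final_layerdrop_rate} "
          f"over batches [{cur_begin:.0f}, {cur_end:.0f}]")
    cur_begin = cur_end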
@@ -1002,7 +1002,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
          query_head_dim: dimension of the query (and key), per head. e.g. 24.
          pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
          dropout: dropout probability for attn_output_weights. Default: 0.0.
-          pos_emb_skip: probability for skipping the pos_emb part of the scores on
+          pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
                 any given call to forward(), in training time.
    """

@@ -1014,8 +1014,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: int,
         pos_head_dim: int,
         dropout: float = 0.0,
-        pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                 (4000.0, 0.075))
+        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
+                                                      (4000.0, 0.0))
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1023,7 +1023,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)

         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1108,7 +1108,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)

-        if not self.training or random.random() >= float(self.pos_emb_skip):
+        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
             pos_emb = self.linear_pos(pos_emb)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
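
Note: with the renamed pos_emb_skip_rate and its new schedule (0.5 at batch 0, decaying to 0.0 by batch 4000 instead of bottoming out at 0.075), the positional part of the attention scores is dropped on roughly half of the forward calls early in training and never dropped after warmup or at inference. A tiny sketch of that gating condition:

import random


def use_pos_scores(training: bool, pos_emb_skip_rate: float) -> bool:
    # Mirrors the condition above: positional scores are always used outside
    # training, and skipped with probability pos_emb_skip_rate during training.
    return (not training) or random.random() >= float(pos_emb_skip_rate)


random.seed(0)
early = sum(use_pos_scores(True, 0.5) for _ in range(10_000)) / 10_000   # rate near batch 0
late = sum(use_pos_scores(True, 0.0) for _ in range(10_000)) / 10_000    # rate from batch 4000 on
print(f"fraction of calls using pos scores: early ~{early:.2f}, after warmup {late:.2f}")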