Make pos_emb dropout rate be constant during training; also cosmetic changes

Daniel Povey 2022-11-15 11:42:12 +08:00
parent 867556200f
commit f76075fd1a


@@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
             feedforward_dim: int,
             dropout: float = 0.1,
             cnn_module_kernel: int = 31,
-            # layer_skip_prob will be overwritten to change warmup begin and end times.
+            # layer_skip_rate will be overwritten to change warmup begin and end times.
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
-            layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
-            dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
+            layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
+            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.embed_dim = embed_dim
         # probability of skipping the entire layer.
-        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward)
-        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
@@ -466,7 +466,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
-        if self.training and random.random() < float(self.layer_skip_prob):
+        if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src
@@ -476,9 +476,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)
         # dropout rate for non-feedforward submodules
-        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
+        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_prob)
+        use_self_attn = (random.random() >= dynamic_skip_rate)
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
@@ -508,7 +508,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn2(
                 src, attn_weights)
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.feed_forward3(src)
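The hunks above only rename dynamic_skip_prob to dynamic_skip_rate; the underlying mechanism is a per-call coin flip that drops a residual branch during training. A minimal, self-contained sketch of that pattern follows; the StochasticSkip wrapper and its skip_rate argument are illustrative names, not code from this commit.

import random

import torch
import torch.nn as nn


class StochasticSkip(nn.Module):
    """Toy wrapper: during training, skip the wrapped residual branch
    with probability `skip_rate`; at eval time always apply it."""

    def __init__(self, module: nn.Module, skip_rate: float = 0.2) -> None:
        super().__init__()
        self.module = module
        self.skip_rate = skip_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_rate = self.skip_rate if self.training else 0.0
        if random.random() >= skip_rate:
            # keep the branch: residual connection around the wrapped module
            return x + self.module(x)
        # skip the branch entirely; the residual path passes x through unchanged
        return x


if __name__ == "__main__":
    layer = StochasticSkip(nn.Linear(16, 16), skip_rate=0.2)
    x = torch.randn(4, 16)
    layer.train()
    y = layer(x)       # branch applied ~80% of the time
    layer.eval()
    y_eval = layer(x)  # branch always applied at eval time
    print(y.shape, y_eval.shape)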
@@ -549,8 +549,8 @@ class ZipformerEncoder(nn.Module):
             dropout: float,
             warmup_begin: float,
             warmup_end: float,
-            initial_layerdrop_prob: float = 0.5,
-            final_layerdrop_prob: float = 0.05,
+            initial_layerdrop_rate: float = 0.5,
+            final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
@@ -567,8 +567,8 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob),
+            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                            (cur_end, final_layerdrop_rate),
                                                             default=0.0)
             cur_begin = cur_end
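For context, the loop above gives each layer its own staggered layerdrop schedule whose skip rate ramps from 0.5 down to 0.05 over that layer's warmup window. Below is a rough, self-contained approximation of how such a piecewise-linear schedule behaves; scheduled_float and the warmup_begin/warmup_end values are illustrative, not icefall's actual ScheduledFloat implementation.

import bisect
from typing import Sequence, Tuple


def scheduled_float(points: Sequence[Tuple[float, float]], batch_index: float) -> float:
    """Piecewise-linear schedule over batch index: before the first breakpoint use
    the first value, after the last breakpoint use the last value, interpolate between."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    if batch_index <= xs[0]:
        return ys[0]
    if batch_index >= xs[-1]:
        return ys[-1]
    i = bisect.bisect_right(xs, batch_index)
    x0, x1 = xs[i - 1], xs[i]
    y0, y1 = ys[i - 1], ys[i]
    return y0 + (y1 - y0) * (batch_index - x0) / (x1 - x0)


# Staggered per-layer schedules, in the spirit of the loop above: each layer's
# window [cur_begin, cur_end) ramps its skip rate from 0.5 down to 0.05.
warmup_begin, warmup_end, num_layers = 0.0, 2000.0, 4  # example values
delta = (warmup_end - warmup_begin) / num_layers
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    schedule = [(cur_begin, 0.5), (cur_end, 0.05)]
    print(i, scheduled_float(schedule, batch_index=1000.0))
    cur_begin = cur_end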
@@ -1018,7 +1018,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        query_head_dim: dimension of the query (and key), per head. e.g. 24.
        pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
        dropout: dropout probability for attn_output_weights. Default: 0.0.
-       pos_emb_skip: probability for skipping the pos_emb part of the scores on
+       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
            any given call to forward(), in training time.
    """
@@ -1030,8 +1030,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             query_head_dim: int,
             pos_head_dim: int,
             dropout: float = 0.0,
-            pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                     (4000.0, 0.05))
+            pos_emb_skip_rate: FloatLike = 0.05,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1039,7 +1038,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1124,7 +1123,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        attn_scores = torch.matmul(q, k)
-        if not self.training or random.random() >= float(self.pos_emb_skip):
+        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
            pos_emb = self.linear_pos(pos_emb)
            seq_len2 = 2 * seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
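The last three hunks carry the substantive change named in the commit title: the pos_emb skip rate becomes a constant 0.05 for the whole of training, instead of a ScheduledFloat decaying from 0.5 to 0.05 over the first 4000 batches. A rough sketch of the gating in isolation follows; combine_scores is a hypothetical helper and the score tensors are random placeholders, not the real attention computation.

import random

import torch


def combine_scores(attn_scores: torch.Tensor,
                   pos_scores: torch.Tensor,
                   training: bool,
                   pos_emb_skip_rate: float = 0.05) -> torch.Tensor:
    """During training, drop the positional component of the attention scores
    with probability pos_emb_skip_rate; otherwise add it to the content scores."""
    if not training or random.random() >= pos_emb_skip_rate:
        return attn_scores + pos_scores
    return attn_scores


# toy usage: (num_heads, batch, seq_len, seq_len) score tensors
attn_scores = torch.randn(4, 2, 10, 10)
pos_scores = torch.randn(4, 2, 10, 10)
out = combine_scores(attn_scores, pos_scores, training=True)
print(out.shape)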