diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0b2ab23d5..ddce1217c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module):
             feedforward_dim: int,
             dropout: float = 0.1,
             cnn_module_kernel: int = 31,
-            # layer_skip_prob will be overwritten to change warmup begin and end times.
+            # layer_skip_rate will be overwritten to change warmup begin and end times.
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
-            layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
-            dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
+            layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
+            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.embed_dim = embed_dim

         # probability of skipping the entire layer.
-        self.layer_skip_prob = copy.deepcopy(layer_skip_prob)
+        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward)
-        self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob)
+        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
@@ -466,7 +466,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
-        if self.training and random.random() < float(self.layer_skip_prob):
+        if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src

@@ -476,9 +476,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)

         # dropout rate for non-feedforward submodules
-        dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0
+        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0

         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_prob)
+        use_self_attn = (random.random() >= dynamic_skip_rate)
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
@@ -508,7 +508,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn2(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob:
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward3(src)
@@ -549,8 +549,8 @@ class ZipformerEncoder(nn.Module):
             dropout: float,
             warmup_begin: float,
             warmup_end: float,
-            initial_layerdrop_prob: float = 0.5,
-            final_layerdrop_prob: float = 0.05,
+            initial_layerdrop_rate: float = 0.5,
+            final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0)
@@ -567,8 +567,8 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob),
+            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                            (cur_end, final_layerdrop_rate),
                                                             default=0.0)
             cur_begin = cur_end
@@ -1018,7 +1018,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
          query_head_dim: dimension of the query (and key), per head. e.g. 24.
          pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
          dropout: dropout probability for attn_output_weights. Default: 0.0.
-         pos_emb_skip: probability for skipping the pos_emb part of the scores on
+         pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
              any given call to forward(), in training time.
     """
@@ -1030,8 +1030,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             query_head_dim: int,
             pos_head_dim: int,
             dropout: float = 0.0,
-            pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                     (4000.0, 0.05))
+            pos_emb_skip_rate: FloatLike = 0.05,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1039,7 +1038,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1124,7 +1123,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)

-        if not self.training or random.random() >= float(self.pos_emb_skip):
+        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
             pos_emb = self.linear_pos(pos_emb)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)