diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index e5ab058fe..8cf451d3e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -164,7 +164,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--pos-dim", type=int, - default="128", + default="96", help="Positional-encoding embedding dimension" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4e7261fd0..efcb25754 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -362,11 +362,11 @@ class ZipformerEncoderLayer(nn.Module): feedforward_dim: int, dropout: float = 0.1, cnn_module_kernel: int = 31, - # layer_skip_prob will be overwritten to change warmup begin and end times. + # layer_skip_rate will be overwritten to change warmup begin and end times. # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom() # to work correctly. - layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), - dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), + layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), + dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0), bypass_max: FloatLike = 1.0, ) -> None: @@ -374,9 +374,9 @@ class ZipformerEncoderLayer(nn.Module): self.embed_dim = embed_dim # probability of skipping the entire layer. - self.layer_skip_prob = copy.deepcopy(layer_skip_prob) + self.layer_skip_rate = copy.deepcopy(layer_skip_rate) # skip probability for dynamic modules (meaning: anything but feedforward) - self.dynamic_skip_prob = copy.deepcopy(dynamic_skip_prob) + self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate) # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads # ever becoming zero. self.bypass_min = copy.deepcopy(bypass_min) @@ -459,16 +459,16 @@ class ZipformerEncoderLayer(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ - if self.training and random.random() < float(self.layer_skip_prob): + if self.training and random.random() < float(self.layer_skip_rate): # skip the layer return src src_orig = src # dropout rate for non-feedforward submodules - dynamic_skip_prob = float(self.dynamic_skip_prob) if self.training else 0.0 + dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0 # multi-headed self-attention module - use_self_attn = (random.random() >= dynamic_skip_prob) + use_self_attn = (random.random() >= dynamic_skip_rate) if torch.jit.is_scripting() or use_self_attn: # attn_weights: (num_heads, batch_size, seq_len, seq_len) @@ -493,7 +493,7 @@ class ZipformerEncoderLayer(nn.Module): src = src + self.self_attn( src, attn_weights) - if torch.jit.is_scripting() or random.random() >= dynamic_skip_prob: + if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate: src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) @@ -534,11 +534,11 @@ class ZipformerEncoder(nn.Module): dropout: float, warmup_begin: float, warmup_end: float, - initial_layerdrop_prob: float = 0.5, - final_layerdrop_prob: float = 0.05, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, ) -> None: super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0) + self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -552,8 +552,8 @@ class ZipformerEncoder(nn.Module): for i in range(num_layers): cur_end = cur_begin + delta # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom() - self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob), - (cur_end, final_layerdrop_prob), + self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), default=0.0) cur_begin = cur_end @@ -1002,7 +1002,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_head_dim: dimension of the query (and key), per head. e.g. 24. pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip: probability for skipping the pos_emb part of the scores on + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on any given call to forward(), in training time. """ @@ -1014,8 +1014,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_head_dim: int, pos_head_dim: int, dropout: float = 0.0, - pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.075)) + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.0)) ) -> None: super().__init__() self.embed_dim = embed_dim @@ -1023,7 +1023,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim self.dropout = dropout - self.pos_emb_skip = copy.deepcopy(pos_emb_skip) + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads @@ -1108,7 +1108,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = torch.matmul(q, k) - if not self.training or random.random() >= float(self.pos_emb_skip): + if not self.training or random.random() >= float(self.pos_emb_skip_rate): pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)