From 614b5b1a52fc3c539ed7333a0244fab2a530931b Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 14 Nov 2022 13:20:31 +0800
Subject: [PATCH] Treat batch_idx==0.0 separately to get
 scan_pessimistic_batches_for_oom() to work. Should not affect results.

---
 .../pruned_transducer_stateless7/zipformer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a8c2afd75..f42022498 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -363,8 +363,10 @@ class ZipformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         # layer_skip_prob will be overwritten to change warmup begin and end times.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
+        # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
+        # to work correctly.
+        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
+        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
         bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
         bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
@@ -552,11 +554,6 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
-        # will be written to, see set_batch_count() Note: in inference time this
-        # may be zero but should be treated as large, we can check if
-        # self.training is true.
-        self.batch_count = 0
-
         self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
 
         self.layers = nn.ModuleList(
@@ -570,7 +567,10 @@ class ZipformerEncoder(nn.Module):
         cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
             cur_end = cur_begin + delta
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
+            # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom() to work correctly.
+            self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
+                                                            (1.0, initial_layerdrop_prob),
+                                                            (cur_begin, initial_layerdrop_prob),
                                                             (cur_end, final_layerdrop_prob))
             cur_begin = cur_end
 
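
Note on the rationale: scan_pessimistic_batches_for_oom() runs trial forward/backward passes over the most memory-hungry batches before training starts, i.e. at batch_count == 0. Under the old defaults, layer_skip_prob was already 0.5 at batch 0, so those trial passes randomly dropped layers and could underestimate the model's true peak memory. Pinning the schedules to 0.0 at batch 0 makes the scan run the full model (the genuinely pessimistic case), while from batch 1 onward the schedules are essentially unchanged, which is why results should not be affected.

Below is a minimal sketch of the assumed behavior, taking ScheduledFloat (from scaling.py) to interpolate linearly between its (batch_index, value) breakpoints and to hold the endpoint values outside them; scheduled_float() here is a hypothetical stand-in, not the icefall API:

    def scheduled_float(points, batch_count):
        # points: (batch_index, value) breakpoints, sorted by batch_index.
        if batch_count <= points[0][0]:
            return points[0][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if batch_count <= x1:
                # linear interpolation between neighboring breakpoints
                return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
        return points[-1][1]  # past the last breakpoint: hold the final value

    old = [(0.0, 0.5), (2000.0, 0.05)]              # old layer_skip_prob default
    new = [(0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)]  # new default from this patch

    for b in (0, 1, 1000, 2000):
        print(b, scheduled_float(old, b), scheduled_float(new, b))
    # batch 0:    old -> 0.5 (layers randomly skipped during the OOM scan),
    #             new -> 0.0 (no skipping, so the scan sees full peak memory);
    # batch >= 1: the two schedules differ negligibly.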