Treat batch_idx == 0.0 separately to get scan_pessimistic_batches_for_oom() to work. Should not affect results.

Author: Daniel Povey
Date:   2022-11-14 13:20:31 +08:00
Parent: cde4ca27ee
Commit: 614b5b1a52


@@ -363,8 +363,10 @@ class ZipformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         # layer_skip_prob will be overwritten to change warmup begin and end times.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
+        # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
+        # to work correctly.
+        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
+        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
         bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
         bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
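
For context: ScheduledFloat here is a piecewise-linear function of the module's batch count. The point of the extra (0.0, 0.0) breakpoint is that scan_pessimistic_batches_for_oom() runs before training, at batch_count == 0; presumably pinning the skip probabilities to zero there lets the scan measure worst-case memory with every layer active, instead of a randomly thinned (and non-deterministic) network. A minimal sketch of the schedule's behavior, assuming linear interpolation between (batch_count, value) breakpoints and clamping outside them (illustrative only, not icefall's actual implementation):

    # Sketch of a piecewise-linear schedule; the real ScheduledFloat lives in
    # icefall's scaling.py -- the names below are illustrative only.
    def piecewise_linear(x, points):
        if x <= points[0][0]:
            return points[0][1]    # clamp below the first breakpoint
        if x >= points[-1][0]:
            return points[-1][1]   # clamp above the last breakpoint
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if x0 <= x <= x1:
                return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

    old = [(0.0, 0.5), (2000.0, 0.05)]              # skip prob is 0.5 already at batch 0
    new = [(0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)]  # pinned to 0.0 at batch 0 only
    assert piecewise_linear(0.0, old) == 0.5
    assert piecewise_linear(0.0, new) == 0.0
    assert piecewise_linear(1.0, new) == 0.5        # training proper is unaffected

From batch 1 onward the new schedule matches the old one, which is why the commit message says results should not be affected.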
@@ -552,11 +554,6 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
-        # will be written to, see set_batch_count() Note: in inference time this
-        # may be zero but should be treated as large, we can check if
-        # self.training is true.
-        self.batch_count = 0
         self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
         self.layers = nn.ModuleList(
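
The deleted comment refers to set_batch_count(), which is defined in the training script and stamps the current batch index onto the model from outside; the encoder's own copy of the attribute is no longer needed. A hedged sketch of that pattern (this only illustrates the idea, the real helper is in icefall's train.py):

    import torch.nn as nn

    # Illustrative version of the set_batch_count() pattern: write the global
    # batch index onto every submodule that declares a batch_count attribute,
    # so schedules like ScheduledFloat can read it during forward().
    def set_batch_count(model: nn.Module, batch_count: float) -> None:
        for module in model.modules():
            if hasattr(module, "batch_count"):
                module.batch_count = batch_count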
@@ -570,7 +567,10 @@ class ZipformerEncoder(nn.Module):
         cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
             cur_end = cur_begin + delta
-            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob))
+            # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
+            self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
+                                                            (1.0, initial_layerdrop_prob),
+                                                            (cur_begin, initial_layerdrop_prob),
+                                                            (cur_end, final_layerdrop_prob))
             cur_begin = cur_end
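
This loop staggers each layer's layerdrop decay over its own (cur_begin, cur_end) window, with later layers warming up later. A quick illustration with hypothetical numbers (warmup_begin, delta and num_layers are whatever the encoder computed; the values below are made up):

    # Hypothetical values, just to show the staggering: each layer's skip
    # probability decays from initial_layerdrop_prob to final_layerdrop_prob
    # over its own window, shifted by delta per layer.
    warmup_begin, delta, num_layers = 100.0, 50.0, 4
    cur_begin = warmup_begin
    for i in range(num_layers):
        cur_end = cur_begin + delta
        print(f"layer {i}: layerdrop decays over batches {cur_begin:.0f}..{cur_end:.0f}")
        cur_begin = cur_end
    # layer 0: layerdrop decays over batches 100..150
    # layer 1: layerdrop decays over batches 150..200
    # layer 2: layerdrop decays over batches 200..250
    # layer 3: layerdrop decays over batches 250..300

With the new leading (0.0, 0.0) and (1.0, initial_layerdrop_prob) points, every layer's schedule is also pinned to zero skip probability at batch 0, matching the change to the defaults in ZipformerEncoderLayer above.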