Treat batch_idx == 0.0 separately to get scan_pessimistic_batches_for_oom() to work. This should not affect results.
This commit is contained in:
parent
cde4ca27ee
commit
614b5b1a52
@ -363,8 +363,10 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
# layer_skip_prob will be overwritten to change warmup begin and end times.
|
||||
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05)),
|
||||
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0)),
|
||||
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
|
||||
# to work correctly.
|
||||
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
|
||||
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
|
||||
bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
|
||||
bypass_clamp_max: FloatLike = 1.0,
|
||||
) -> None:
|
||||
@ -552,11 +554,6 @@ class ZipformerEncoder(nn.Module):
|
||||
final_layerdrop_prob: float = 0.05,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# will be written to, see set_batch_count().  Note: at inference time this
|
||||
# may be zero but should be treated as large, we can check if
|
||||
# self.training is true.
|
||||
self.batch_count = 0
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
@ -570,7 +567,10 @@ class ZipformerEncoder(nn.Module):
|
||||
cur_begin = warmup_begin # interpreted as a training batch index
|
||||
for i in range(num_layers):
|
||||
cur_end = cur_begin + delta
|
||||
self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
|
||||
# treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom() to work correctly.
|
||||
self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
|
||||
(1.0, initial_layerdrop_prob),
|
||||
(cur_begin, initial_layerdrop_prob),
|
||||
(cur_end, final_layerdrop_prob))
|
||||
cur_begin = cur_end
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user