Refactorize the scheduling code a little

Daniel Povey 2022-11-14 14:52:14 +08:00
parent b32dec1119
commit e1fb25262a
2 changed files with 22 additions and 16 deletions

View File

@@ -992,14 +992,19 @@ class ScheduledFloat(torch.nn.Module):
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set or not in training mode or in
torch.jit scripting mode.
"""
def __init__(self,
*args):
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = 0
self.name = ''
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
@@ -1012,19 +1017,21 @@ class ScheduledFloat(torch.nn.Module):
self.schedule)
def __float__(self):
print_prob = 0.0001
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return ans
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return ans
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
@@ -1032,7 +1039,7 @@ class ScheduledFloat(torch.nn.Module):
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return ans
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
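
The value lookup changed above is easiest to follow outside the diff. The following is a minimal, standalone sketch that is not part of the commit: schedule_value is a hypothetical helper that re-implements the piecewise-linear lookup of __float__ together with the new default fallback, using the schedule format shown above (a list of (batch_count, value) breakpoints with increasing batch_count).

# Standalone sketch of the lookup performed by ScheduledFloat.__float__ above;
# `schedule_value` is a hypothetical helper name, not part of the commit.
def schedule_value(schedule, batch_count, default=0.0):
    if batch_count is None:
        # New behaviour: fall back to `default` when batch_count has not been set.
        return float(default)
    if batch_count <= schedule[0][0]:
        # Before the first breakpoint, use the first y value.
        return float(schedule[0][1])
    if batch_count >= schedule[-1][0]:
        # After the last breakpoint, use the last y value.
        return float(schedule[-1][1])
    for (cur_x, cur_y), (next_x, next_y) in zip(schedule[:-1], schedule[1:]):
        if cur_x <= batch_count <= next_x:
            # Linear interpolation between the two surrounding breakpoints.
            return float(cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x))

# The new layer_skip_prob default from ZipformerEncoderLayer below:
sched = [(0.0, 0.5), (2000.0, 0.05)]
print(schedule_value(sched, None))    # 0.0    -> default (batch_count unset)
print(schedule_value(sched, 0))       # 0.5
print(schedule_value(sched, 1000.0))  # 0.275  -> halfway between 0.5 and 0.05
print(schedule_value(sched, 50000))   # 0.05   -> clamped to the last y value

With the layer_skip_prob default used below, the value ramps linearly from 0.5 at batch 0 to 0.05 at batch 2000, stays at 0.05 afterwards, and falls back to `default` whenever batch_count is unset.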

View File

@@ -365,9 +365,9 @@ class ZipformerEncoderLayer(nn.Module):
# layer_skip_prob will be overwritten to change warmup begin and end times.
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
# to work correctly.
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
bypass_clamp_max: FloatLike = 1.0,
) -> None:
super(ZipformerEncoderLayer, self).__init__()
@@ -568,10 +568,9 @@ class ZipformerEncoder(nn.Module):
for i in range(num_layers):
cur_end = cur_begin + delta
# treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
(1.0, initial_layerdrop_prob),
(cur_begin, initial_layerdrop_prob),
(cur_end, final_layerdrop_prob))
self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
(cur_end, final_layerdrop_prob),
default=0.0)
cur_begin = cur_end
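
To make the effect of this loop concrete, the sketch below prints the per-layer warmup windows it creates. The starting cur_begin, the delta, and the two layerdrop probabilities are hypothetical example values; in the real method they are computed just before this loop, outside the shown hunk.

# Illustrative only: hypothetical numbers, chosen to show the staggered windows
# that the rewritten loop assigns to successive layers.
num_layers = 4
initial_layerdrop_prob = 0.5   # example value, not from the commit
final_layerdrop_prob = 0.05    # example value, not from the commit
cur_begin = 1000.0             # hypothetical start of the first layer's window
delta = 500.0                  # hypothetical spacing between consecutive layers

for i in range(num_layers):
    cur_end = cur_begin + delta
    # Each layer now ramps from initial_layerdrop_prob down to final_layerdrop_prob
    # over its own (cur_begin, cur_end) window, relying on default=0.0 instead of
    # the removed (0.0, 0.0) and (1.0, initial_layerdrop_prob) breakpoints.
    print(f"layer {i}: ({cur_begin}, {initial_layerdrop_prob}) -> ({cur_end}, {final_layerdrop_prob}), default=0.0")
    cur_begin = cur_end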
@@ -1028,7 +1027,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_head_dim: int,
dropout: float = 0.0,
pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
(4000.0, 0.025))
(4000.0, 0.025), default=0.0)
) -> None:
super().__init__()
self.embed_dim = embed_dim
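
For completeness, a hedged sketch of how a FloatLike field such as pos_emb_skip is typically consumed: the caller converts it with float() each time the value is needed, which is the path that goes through ScheduledFloat.__float__ above and therefore yields `default` when batch_count is unset, outside training, or under torch.jit scripting. The function name and surrounding logic here are illustrative, not taken from the real forward() code.

import random

# Illustrative consumer of a FloatLike value (which, per the defaults in this diff,
# may be either a plain float or a ScheduledFloat); not a quote of
# RelPositionMultiheadAttentionWeights.forward().
def should_skip_pos_emb(pos_emb_skip) -> bool:
    skip_prob = float(pos_emb_skip)   # scheduled value in training, `default` otherwise
    return random.random() < skip_prob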