Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Refactorize the scheduling code a little
parent b32dec1119
commit e1fb25262a
@@ -992,14 +992,19 @@ class ScheduledFloat(torch.nn.Module):
     first x or after the last x, we just use the first or last y value.
 
     Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set or in training mode or in
+    torch.jit scripting mode.
     """
     def __init__(self,
-                 *args):
+                 *args,
+                 default: float = 0.0):
         super().__init__()
         # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = 0
-        self.name = ''
+        self.batch_count = None
+        self.name = None
+        self.default = default
         assert len(args) >= 1
         for (x,y) in args:
             assert x >= 0
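For context, the new `default` argument gives ScheduledFloat a fallback value for when the schedule cannot be consulted (for example, when `batch_count` has never been written by a training loop). A rough usage sketch, assuming the class is imported from the recipe's scaling module; the import path and the attribute writes are illustrative, mirroring the comment in `__init__` above:

from scaling import ScheduledFloat  # assumed import path; adjust to the recipe layout

# Dropout that decays linearly from 0.2 at batch 0 to 0.0 at batch 4000.
dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

# The training loop is expected to write these two attributes (see the comment above).
dropout.batch_count = 1000.0
dropout.name = "encoder.dropout"

# float() either consults the schedule (about 0.15 at batch 1000) or returns `default`,
# depending on the mode checks added to __float__ in this commit.
p = float(dropout)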
@@ -1012,19 +1017,21 @@ class ScheduledFloat(torch.nn.Module):
                                             self.schedule)
 
     def __float__(self):
-        print_prob = 0.0001
+        print_prob = 0.0002
         def maybe_print(ans):
             if random.random() < print_prob:
                 logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
         batch_count = self.batch_count
+        if batch_count is None or self.training or torch.jit.is_scripting():
+            return float(self.default)
         if batch_count <= self.schedule[0][0]:
             ans = self.schedule[0][1]
             maybe_print(ans)
-            return ans
+            return float(ans)
         elif batch_count >= self.schedule[-1][0]:
             ans = self.schedule[-1][1]
             maybe_print(ans)
-            return ans
+            return float(ans)
         else:
             cur_x, cur_y = self.schedule[0]
             for i in range(1, len(self.schedule)):
@@ -1032,7 +1039,7 @@ class ScheduledFloat(torch.nn.Module):
                 if batch_count >= cur_x and batch_count <= next_x:
                     ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
                     maybe_print(ans)
-                    return ans
+                    return float(ans)
                 cur_x, cur_y = next_x, next_y
             assert False
 
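The interpolation itself is unchanged by this commit; only the `default` fallback and the explicit float() casts are new. As a self-contained sketch of what `__float__` computes, with a hypothetical helper name (`schedule_value` is not part of the diff):

from typing import Optional, Sequence, Tuple

def schedule_value(schedule: Sequence[Tuple[float, float]],
                   batch_count: Optional[float],
                   default: float = 0.0) -> float:
    """Piecewise-linear lookup over (batch_count, value) pairs, clamped at both ends."""
    if batch_count is None:
        # Mirrors the new fallback: use `default` when the batch count is unknown.
        return float(default)
    if batch_count <= schedule[0][0]:
        return float(schedule[0][1])
    if batch_count >= schedule[-1][0]:
        return float(schedule[-1][1])
    cur_x, cur_y = schedule[0]
    for next_x, next_y in schedule[1:]:
        if cur_x <= batch_count <= next_x:
            return float(cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x))
        cur_x, cur_y = next_x, next_y
    raise AssertionError("unreachable for a schedule sorted on x")

# e.g. the dropout schedule from the docstring: 0.2 -> 0.0 over 4000 batches
assert abs(schedule_value([(0.0, 0.2), (4000.0, 0.0)], 1000.0) - 0.15) < 1e-9
assert schedule_value([(0.0, 0.2), (4000.0, 0.0)], None, default=0.0) == 0.0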
@@ -365,9 +365,9 @@ class ZipformerEncoderLayer(nn.Module):
             # layer_skip_prob will be overwritten to change warmup begin and end times.
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
-            layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
-            dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
-            bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
+            layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
+            dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
+            bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
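With the `(0.0, 0.0), (1.0, ...)` warm-up special-casing dropped, the default schedules above are plain two-point ramps. A quick check of the new `layer_skip_prob` default, `(0.0, 0.5)` down to `(2000.0, 0.05)`; the helper below is only illustrative:

# Linear ramp clamped at both ends, matching the two-point default schedule above.
def ramp(batch_count, x0=0.0, y0=0.5, x1=2000.0, y1=0.05):
    batch_count = min(max(batch_count, x0), x1)
    return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

assert ramp(0.0) == 0.5                    # 0.5 at the start of training
assert abs(ramp(1000.0) - 0.275) < 1e-9    # halfway through the ramp
assert abs(ramp(5000.0) - 0.05) < 1e-9     # clamped to 0.05 after batch 2000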
@@ -568,10 +568,9 @@ class ZipformerEncoder(nn.Module):
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
-                                                            (1.0, initial_layerdrop_prob),
-                                                            (cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob))
+            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
+                                                            (cur_end, final_layerdrop_prob),
+                                                            default=0.0)
             cur_begin = cur_end
 
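The per-layer schedules set up in this loop are now just the `(cur_begin, ...)` to `(cur_end, ...)` ramp plus `default=0.0`. A rough sketch of the staggering, with assumed values for the warm-up window, `delta` and the layerdrop probabilities; none of these numbers come from the diff:

num_layers = 6
initial_layerdrop_prob, final_layerdrop_prob = 0.5, 0.05
warmup_begin, warmup_end = 0.0, 2000.0            # assumed overall warm-up window
delta = (warmup_end - warmup_begin) / num_layers  # assumed definition of delta

schedules = []
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    # Each layer ramps its skip prob down over its own [cur_begin, cur_end] batch window;
    # the old (0.0, 0.0), (1.0, ...) prefix for batch_index == 0 is gone in this commit.
    schedules.append(((cur_begin, initial_layerdrop_prob),
                      (cur_end, final_layerdrop_prob)))
    cur_begin = cur_end

for i, ((b, _), (e, _)) in enumerate(schedules):
    print(f"layer {i}: layerdrop ramps 0.5 -> 0.05 over batches [{b:.0f}, {e:.0f}]")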
@@ -1028,7 +1027,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             pos_head_dim: int,
             dropout: float = 0.0,
             pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                     (4000.0, 0.025))
+                                                     (4000.0, 0.025), default=0.0)
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim