mirror of https://github.com/k2-fsa/icefall.git
Refactor ScheduledFloat to include PiecewiseLinear
This commit is contained in:
parent f688066517
commit ae73469b7e
@@ -31,6 +31,114 @@ from torch.nn import Embedding as ScaledEmbedding

class PiecewiseLinear(object):
    """
    Piecewise linear function, from float to float, specified as a nonempty list of (x, y) pairs
    with the x values in increasing order.  x values < [initial x] or > [final x] are mapped to
    [initial y] and [final y] respectively.  (A short usage sketch follows this class.)
    """
    def __init__(self, *args):
        assert len(args) >= 1
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            # Copy construction.  (__init__ cannot return a different object,
            # so copy the pairs of the existing PiecewiseLinear instead.)
            self.pairs = list(args[0].pairs)
            return
        for (x, y) in args:
            assert isinstance(x, float) or isinstance(x, int)
            assert isinstance(y, float) or isinstance(y, int)

        # The x values must be strictly increasing.
        for i in range(len(args) - 1):
            assert args[i + 1][0] > args[i][0], args
        self.pairs = [ (float(x), float(y)) for x, y in args ]

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'

    def __call__(self, x):
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        elif x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        else:
            cur_x, cur_y = self.pairs[0]
            for i in range(1, len(self.pairs)):
                next_x, next_y = self.pairs[i]
                if x >= cur_x and x <= next_x:
                    # Linear interpolation within the segment [cur_x, next_x].
                    return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
                cur_x, cur_y = next_x, next_y
            assert False

    def __mul__(self, alpha):
        return PiecewiseLinear(
            * [(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        if isinstance(x, float) or isinstance(x, int):
            return PiecewiseLinear(
                * [(p[0], p[1] + x) for p in self.pairs])
        s, x = self.get_common_basis(x)
        return PiecewiseLinear(
            * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)])

    def max(self, x):
        if isinstance(x, float) or isinstance(x, int):
            x = PiecewiseLinear( (0, x) )
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])

    def min(self, x):
        if isinstance(x, float) or isinstance(x, int):
            x = PiecewiseLinear( (0, x) )
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])

    def __eq__(self, other):
        return self.pairs == other.pairs

    def get_common_basis(self,
                         p: 'PiecewiseLinear',
                         include_crossings: bool = False):
        """
        Returns (self_mod, p_mod), which are piecewise linear functions equivalent to
        self and p respectively, but defined on the same x values.  (See also the
        usage sketch following this class.)

          p: the other piecewise linear function
          include_crossings: if true, also include x values at the positions where
             the functions represented by self and p cross.
        """
        assert isinstance(p, PiecewiseLinear)

        # get sorted x-values without repetition.
        x_vals = sorted(set([ x for x, y in self.pairs ] + [ x for x, y in p.pairs ]))
        y_vals1 = [ self(x) for x in x_vals ]
        y_vals2 = [ p(x) for x in x_vals ]

        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]):
                    # the two linear segments cross somewhere inside this subsegment.
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i+1] - y_vals2[i+1])
                    # `pos`, between 0 and 1, gives the relative x position,
                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i])
                    extra_x_vals.append(extra_x_val)
            if len(extra_x_vals) > 0:
                x_vals = sorted(set(x_vals + extra_x_vals))
                y_vals1 = [ self(x) for x in x_vals ]
                y_vals2 = [ p(x) for x in x_vals ]
        return ( PiecewiseLinear(* zip(x_vals, y_vals1)),
                 PiecewiseLinear(* zip(x_vals, y_vals2)) )

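As referenced in the docstrings above, here is a minimal usage sketch of PiecewiseLinear. The breakpoint values are invented for illustration and are not taken from the commit; the printed results assume the code above runs as written.

p = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
print(p(-5.0))     # 10.0 -- x below the first breakpoint clamps to the first y value
print(p(50.0))     # 5.0  -- linear interpolation inside the range
print(p(200.0))    # 0.0  -- x above the last breakpoint clamps to the last y value

q = p * 0.5 + 1.0  # scale the y values, then shift them
print(q)           # PiecewiseLinear((0.0, 6.0), (100.0, 1.0))

r = PiecewiseLinear((0.0, 0.0), (100.0, 10.0))
print(p.max(r))    # PiecewiseLinear((0.0, 10.0), (50.0, 5.0), (100.0, 10.0))
                   # a breakpoint is added at x=50 where p and r cross

s = PiecewiseLinear((25.0, 1.0), (75.0, 3.0))
a, b = p.get_common_basis(s)
print(a)           # PiecewiseLinear((0.0, 10.0), (25.0, 7.5), (75.0, 2.5), (100.0, 0.0))
print(b)           # PiecewiseLinear((0.0, 1.0), (25.0, 1.0), (75.0, 3.0), (100.0, 3.0))
                   # both functions re-expressed over the union of their x values
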
class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();

@@ -56,43 +164,36 @@ class ScheduledFloat(torch.nn.Module):
        self.batch_count = None
        self.name = None
        self.default = default
        assert len(args) >= 1
        for (x, y) in args:
            assert x >= 0
        for i in range(len(args) - 1):
            assert args[i + 1] > args[i], args
        self.schedule = args                    # before this commit: schedule kept as the raw (x, y) pairs
        self.schedule = PiecewiseLinear(*args)  # after this commit: schedule wrapped in a PiecewiseLinear

    def extra_repr(self) -> str:
        # Before this commit:
        return 'batch_count={}, schedule={}'.format(self.batch_count,
                                                    self.schedule)
        # After this commit (self.schedule is now a PiecewiseLinear):
        return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs)[1:-1]}'

    def __float__(self):
        print_prob = 0.0002
        def maybe_print(ans):
            if random.random() < print_prob:
                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
        batch_count = self.batch_count
        if batch_count is None or not self.training or torch.jit.is_scripting():
            return float(self.default)
        if batch_count <= self.schedule[0][0]:
            ans = self.schedule[0][1]
            maybe_print(ans)
            return float(ans)
        elif batch_count >= self.schedule[-1][0]:
            ans = self.schedule[-1][1]
            maybe_print(ans)
            return float(ans)
        else:
            cur_x, cur_y = self.schedule[0]
            for i in range(1, len(self.schedule)):
                next_x, next_y = self.schedule[i]
                if batch_count >= cur_x and batch_count <= next_x:
                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
                    maybe_print(ans)
                    return float(ans)
                cur_x, cur_y = next_x, next_y
            assert False
        # New in this commit: the inline interpolation above is replaced by a call
        # to the PiecewiseLinear schedule.
        ans = self.schedule(self.batch_count)
        if random.random() < 0.0002:
            logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
        return ans

    def __add__(self, x):
        if isinstance(x, float) or isinstance(x, int):
            return ScheduledFloat(self.schedule + x,
                                  default=self.default)
        else:
            return ScheduledFloat(self.schedule + x.schedule,
                                  default=self.default + x.default)

    def max(self, x):
        if isinstance(x, float) or isinstance(x, int):
            return ScheduledFloat(self.schedule.max(x),
                                  default=self.default)
        else:
            return ScheduledFloat(self.schedule.max(x.schedule),
                                  default=max(self.default, x.default))


FloatLike = Union[float, ScheduledFloat]

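A rough sketch of how a ScheduledFloat is meant to be consumed. The constructor form (a sequence of (x, y) pairs plus a `default` keyword) is implied by the code above; the step where the training loop writes batch_count is an assumption of this sketch, not something shown in the diff.

# Hypothetical schedule: 0.3 up to batch 1000, decaying linearly to 0.1 by batch 20000.
warmup = ScheduledFloat((1000.0, 0.3), (20000.0, 0.1), default=0.3)
warmup.name = "warmup"

# Assumed to be done by the training loop:
warmup.batch_count = 10500.0

# In training mode, float() evaluates the PiecewiseLinear schedule at batch_count;
# if batch_count is unset, or the module is in eval mode, it falls back to `default`.
value = float(warmup)   # ~0.2 for the schedule and batch_count above
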
@@ -2083,10 +2184,41 @@ def _test_softmax():

def _test_piecewise_linear():
    p = PiecewiseLinear( (0, 10.0) )
    for x in [-100, 0, 100]:
        assert p(x) == 10.0
    p = PiecewiseLinear( (0, 10.0), (1, 0.0) )
    for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]:
        print("x, y = ", x, y)
        assert p(x) == y, (x, p(x), y)

    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
    x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ]
    pq = p.max(q)
    for x in x_vals:
        y1 = max(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p.min(q)
    for x in x_vals:
        y1 = min(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p + q
    for x in x_vals:
        y1 = p(x) + q(x)
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001


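For intuition about the include_crossings path that the max/min checks above exercise, here is what the common basis of the same p and q looks like; this is a sketch for illustration and is not part of the commit.

p = PiecewiseLinear((0, 10.0), (1, 0.0))
q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
s, t = p.get_common_basis(q, include_crossings=True)
print([x for x, y in s.pairs])
# roughly [0.0, 0.5, 0.577, 0.6, 0.9, 1.0]: the union of the original x values
# plus the two x positions where p and q cross.
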
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_piecewise_linear()
    _test_softmax()
    _test_whiten()
    _test_max_eig()