Bug fixes to ScheduledFloat

This commit is contained in:
Daniel Povey 2023-01-04 20:54:05 +08:00
parent ae73469b7e
commit b973929d7c

View File

@ -40,29 +40,27 @@ class PiecewiseLinear(object):
def __init__(self, *args):
assert len(args) >= 1
if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
return args[0]
for (x,y) in args:
self.pairs = list(args[0].pairs)
else:
self.pairs = [ (float(x), float(y)) for x,y in args ]
for (x,y) in self.pairs:
assert isinstance(x, float) or isinstance(x, int)
assert isinstance(y, float) or isinstance(y, int)
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.pairs = [ (float(x), float(y)) for x,y in args ]
for i in range(len(self.pairs) - 1):
assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs
def __str__(self):
# e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
return f'PiecewiseLinear({str(self.pairs)[1:-1]})'
def __call__(self, x):
print("x=", x)
if x <= self.pairs[0][0]:
print("a", self.pairs[0][1])
return self.pairs[0][1]
elif x >= self.pairs[-1][0]:
print("b")
return self.pairs[-1][1]
else:
print("c")
cur_x, cur_y = self.pairs[0]
for i in range(1, len(self.pairs)):
next_x, next_y = self.pairs[i]