Refactor ScheduledFloat to include PiecewiseLinear

This commit is contained in:
Daniel Povey 2023-01-04 20:46:42 +08:00
parent f688066517
commit ae73469b7e

View File

@ -31,6 +31,114 @@ from torch.nn import Embedding as ScaledEmbedding
class PiecewiseLinear(object):
    """
    Piecewise linear function, from float to float, specified as a nonempty list of (x, y)
    pairs with the x values in strictly increasing order.  Inputs below the first x are
    mapped to the first y; inputs above the last x are mapped to the last y.

    Construct it either from (x, y) pairs, e.g. PiecewiseLinear((0, 10.0), (100, 0.0)),
    or from a single existing PiecewiseLinear object, whose pairs are copied.

    Attributes:
      pairs: list of (x, y) tuples of floats, sorted by x.
    """
    def __init__(self, *args):
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            # Copy-construction.  __init__ must return None (returning an object raises
            # TypeError), so copy the breakpoints into this instance instead.
            self.pairs = list(args[0].pairs)
            return
        for (x, y) in args:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)
        for i in range(len(args) - 1):
            # Compare the x values only.  Comparing the whole (x, y) tuples would accept
            # a repeated x with a larger y, which later causes a division by zero.
            assert args[i + 1][0] > args[i][0], args
        self.pairs = [(float(x), float(y)) for x, y in args]

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'

    def __call__(self, x):
        """Evaluate the function at x, clamping outside the range of breakpoints."""
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        if x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        # x lies strictly between the first and last breakpoints: interpolate
        # linearly inside the segment that contains it.
        for (cur_x, cur_y), (next_x, next_y) in zip(self.pairs, self.pairs[1:]):
            if cur_x <= x <= next_x:
                return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
        # Only reachable when x does not compare with the breakpoints (e.g. NaN).
        raise AssertionError(f"could not evaluate PiecewiseLinear at x={x}")

    def __mul__(self, alpha):
        """Return a new function with every y value scaled by alpha."""
        return PiecewiseLinear(
            *[(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        """Return self + x, where x is a number or another PiecewiseLinear."""
        if isinstance(x, (float, int)):
            return PiecewiseLinear(
                *[(p[0], p[1] + x) for p in self.pairs])
        s, x = self.get_common_basis(x)
        return PiecewiseLinear(
            *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)])

    def max(self, x):
        """Return the pointwise maximum of self and x (a number or PiecewiseLinear)."""
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        # Crossing points must become breakpoints, otherwise the max of two
        # linear pieces would not itself be linear between breakpoints.
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])

    def min(self, x):
        """Return the pointwise minimum of self and x (a number or PiecewiseLinear)."""
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])

    def __eq__(self, other):
        # Let Python fall back to its default handling for unrelated types
        # instead of failing on a missing `.pairs` attribute.
        if not isinstance(other, PiecewiseLinear):
            return NotImplemented
        return self.pairs == other.pairs

    def get_common_basis(self,
                         p: 'PiecewiseLinear',
                         include_crossings: bool = False):
        """
        Returns (self_mod, p_mod) which are equivalent piecewise linear
        functions to self and p, but with the same x values.

          p: the other piecewise linear function
          include_crossings: if true, include in the x values positions
              where the functions indicated by this and p cross.
        """
        assert isinstance(p, PiecewiseLinear)

        # get sorted x-values without repetition.
        x_vals = sorted(set([x for x, y in self.pairs] + [x for x, y in p.pairs]))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]

        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]):
                    # if the two lines in this subsegment potentially cross each other..
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i+1] - y_vals2[i+1])
                    # `pos`, between 0 and 1, gives the relative x position,
                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
                    # The denominator is nonzero here: the test above fails
                    # when both differences are zero.
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i])
                    extra_x_vals.append(extra_x_val)
            if len(extra_x_vals) > 0:
                x_vals = sorted(set(x_vals + extra_x_vals))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]
        return (PiecewiseLinear(*zip(x_vals, y_vals1)),
                PiecewiseLinear(*zip(x_vals, y_vals2)))
class ScheduledFloat(torch.nn.Module): class ScheduledFloat(torch.nn.Module):
""" """
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@ -56,43 +164,36 @@ class ScheduledFloat(torch.nn.Module):
self.batch_count = None self.batch_count = None
self.name = None self.name = None
self.default = default self.default = default
assert len(args) >= 1 self.schedule = PiecewiseLinear(*args)
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
    """
    Return a description of this object's state: the current batch_count and the
    schedule's (x, y) breakpoints, printed without the enclosing list brackets.
    """
    # Slice the *string* form of the list to strip '[' and ']'.  Slicing the list
    # itself would drop the first and last breakpoints from the output.
    return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs)[1:-1]}'
def __float__(self):
    """
    Return the current value as a float.

    Returns float(self.default) when batch_count is unset, when the module is not in
    training mode, or under TorchScript scripting; otherwise returns the schedule
    evaluated at the current batch_count.
    """
    batch_count = self.batch_count
    if batch_count is None or not self.training or torch.jit.is_scripting():
        return float(self.default)
    # Coerce explicitly: Python requires __float__ to return a real float, and the
    # pre-refactor code also wrapped the result in float().
    ans = float(self.schedule(batch_count))
    # Log only occasionally, so this does not flood the log when called often.
    if random.random() < 0.0002:
        logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
    return ans
def __add__(self, x):
    """
    Return a new ScheduledFloat equal to self + x, where x is either a plain
    number or another ScheduledFloat.
    """
    if isinstance(x, (float, int)):
        # NOTE(review): the scalar is added to the schedule only; `default` is passed
        # through unshifted, unlike the ScheduledFloat branch -- confirm intended.
        summed, default = self.schedule + x, self.default
    else:
        summed, default = self.schedule + x.schedule, self.default + x.default
    return ScheduledFloat(summed, default=default)
def max(self, x):
    """
    Return a new ScheduledFloat that is the pointwise maximum of self and x,
    where x is either a plain number or another ScheduledFloat.
    """
    if isinstance(x, (float, int)):
        # Scalar case: only the schedule is changed; `default` is kept as it is.
        upper, default = self.schedule.max(x), self.default
    else:
        upper, default = self.schedule.max(x.schedule), max(self.default, x.default)
    return ScheduledFloat(upper, default=default)
FloatLike = Union[float, ScheduledFloat] FloatLike = Union[float, ScheduledFloat]
@ -2083,10 +2184,41 @@ def _test_softmax():
def _test_piecewise_linear():
    """
    Check PiecewiseLinear: a constant function, clamping and interpolation on a
    two-point function, and max / min / sum against pointwise evaluation.
    """
    flat = PiecewiseLinear((0, 10.0))
    for probe in (-100, 0, 100):
        assert flat(probe) == 10.0

    p = PiecewiseLinear((0, 10.0), (1, 0.0))
    expected = {-100: 10.0, 0: 10.0, 0.5: 5.0, 1: 0.0, 2: 0.0}
    for x, y in expected.items():
        print("x, y = ", x, y)
        assert p(x) == y, (x, p(x), y)

    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
    x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]

    def agrees(combined, pointwise):
        # `combined` must match `pointwise` applied to p and q at every probe.
        for x in x_vals:
            y1 = pointwise(p(x), q(x))
            y2 = combined(x)
            assert abs(y1 - y2) < 0.001

    agrees(p.max(q), max)
    agrees(p.min(q), min)
    agrees(p + q, lambda a, b: a + b)
if __name__ == "__main__": if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
_test_piecewise_linear()
_test_softmax() _test_softmax()
_test_whiten() _test_whiten()
_test_max_eig() _test_max_eig()