From b973929d7c29e37cd7268709262d9cb223eb1482 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 4 Jan 2023 20:54:05 +0800 Subject: [PATCH] Bug fixes to ScheduledFloat --- .../ASR/pruned_transducer_stateless7/scaling.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index ad7f7a484..5d88aeccc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -40,29 +40,27 @@ class PiecewiseLinear(object): def __init__(self, *args): assert len(args) >= 1 if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - return args[0] - for (x,y) in args: + self.pairs = list(args[0].pairs) + else: + self.pairs = [ (float(x), float(y)) for x,y in args ] + for (x,y) in self.pairs: assert isinstance(x, float) or isinstance(x, int) assert isinstance(y, float) or isinstance(y, int) - for i in range(len(args) - 1): - assert args[i + 1] > args[i], args - self.pairs = [ (float(x), float(y)) for x,y in args ] + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs + def __str__(self): # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' return f'PiecewiseLinear({str(self.pairs)[1:-1]})' def __call__(self, x): - print("x=", x) if x <= self.pairs[0][0]: - print("a", self.pairs[0][1]) return self.pairs[0][1] elif x >= self.pairs[-1][0]: - print("b") return self.pairs[-1][1] else: - print("c") cur_x, cur_y = self.pairs[0] for i in range(1, len(self.pairs)): next_x, next_y = self.pairs[i]