diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ef998eb4a..5d88aeccc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -31,6 +31,112 @@
 from torch.nn import Embedding as ScaledEmbedding
 
 
+class PiecewiseLinear(object):
+    """
+    Piecewise linear function, from float to float, specified as a nonempty list of (x,y)
+    pairs with the x values in order.  x values <[initial x] or >[final x] are mapped to
+    [initial y] and [final y] respectively.
+    """
+    def __init__(self, *args):
+        assert len(args) >= 1
+        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+            self.pairs = list(args[0].pairs)
+        else:
+            self.pairs = [ (float(x), float(y)) for x, y in args ]
+        for (x, y) in self.pairs:
+            assert isinstance(x, float) or isinstance(x, int)
+            assert isinstance(y, float) or isinstance(y, int)
+
+        for i in range(len(self.pairs) - 1):
+            assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs
+
+    def __str__(self):
+        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'
+
+    def __call__(self, x):
+        if x <= self.pairs[0][0]:
+            return self.pairs[0][1]
+        elif x >= self.pairs[-1][0]:
+            return self.pairs[-1][1]
+        else:
+            cur_x, cur_y = self.pairs[0]
+            for i in range(1, len(self.pairs)):
+                next_x, next_y = self.pairs[i]
+                if x >= cur_x and x <= next_x:
+                    return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+    def __mul__(self, alpha):
+        return PiecewiseLinear(
+            *[(x, y * alpha) for x, y in self.pairs])
+
+    def __add__(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            return PiecewiseLinear(
+                *[(p[0], p[1] + x) for p in self.pairs])
+        s, x = self.get_common_basis(x)
+        return PiecewiseLinear(
+            *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)])
+
+    def max(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            x = PiecewiseLinear((0, x))
+        s, x = self.get_common_basis(x, include_crossings=True)
+        return PiecewiseLinear(
+            *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])
+
+    def min(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            x = PiecewiseLinear((0, x))
+        s, x = self.get_common_basis(x, include_crossings=True)
+        return PiecewiseLinear(
+            *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)])
+
+    def __eq__(self, other):
+        return self.pairs == other.pairs
+
+    def get_common_basis(self,
+                         p: 'PiecewiseLinear',
+                         include_crossings: bool = False):
+        """
+        Returns (self_mod, p_mod), which are piecewise linear functions equivalent
+        to self and p but defined on the same x values.
+
+        p: the other piecewise linear function
+        include_crossings: if True, also include in the x values the positions
+              where the functions represented by self and p cross.
+        """
+        assert isinstance(p, PiecewiseLinear)
+
+        # get sorted x-values without repetition.
+        x_vals = sorted(set([ x for x, y in self.pairs ] + [ x for x, y in p.pairs ]))
+        y_vals1 = [ self(x) for x in x_vals ]
+        y_vals2 = [ p(x) for x in x_vals ]
+
+        if include_crossings:
+            extra_x_vals = []
+            for i in range(len(x_vals) - 1):
+                if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]):
+                    # The two lines cross somewhere within this sub-segment.
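+                    # Locate the crossing by linear interpolation: the gap
+                    # between the two functions shrinks linearly from diff_cur
+                    # at x_vals[i] to diff_next (with opposite sign) at
+                    # x_vals[i+1], so it reaches zero partway through.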
+                    diff_cur = abs(y_vals1[i] - y_vals2[i])
+                    diff_next = abs(y_vals1[i+1] - y_vals2[i+1])
+                    # `pos`, between 0 and 1, gives the relative x position,
+                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
+                    pos = diff_cur / (diff_cur + diff_next)
+                    extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i])
+                    extra_x_vals.append(extra_x_val)
+            if len(extra_x_vals) > 0:
+                x_vals = sorted(set(x_vals + extra_x_vals))
+                y_vals1 = [ self(x) for x in x_vals ]
+                y_vals2 = [ p(x) for x in x_vals ]
+        return (PiecewiseLinear(*zip(x_vals, y_vals1)),
+                PiecewiseLinear(*zip(x_vals, y_vals2)))
+
+
+
 class ScheduledFloat(torch.nn.Module):
     """
     This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -56,43 +162,36 @@ class ScheduledFloat(torch.nn.Module):
         self.batch_count = None
         self.name = None
         self.default = default
-        assert len(args) >= 1
-        for (x, y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
+        self.schedule = PiecewiseLinear(*args)
 
     def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
+        return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs)[1:-1]}'
 
     def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
         batch_count = self.batch_count
         if batch_count is None or not self.training or torch.jit.is_scripting():
             return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
         else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
+            ans = self.schedule(self.batch_count)
+            if random.random() < 0.0002:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+            return ans
+
+    def __add__(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            return ScheduledFloat(self.schedule + x,
+                                  default=self.default)
+        else:
+            return ScheduledFloat(self.schedule + x.schedule,
+                                  default=self.default + x.default)
+
+    def max(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            return ScheduledFloat(self.schedule.max(x),
+                                  default=self.default)
+        else:
+            return ScheduledFloat(self.schedule.max(x.schedule),
+                                  default=max(self.default, x.default))
 
 
 FloatLike = Union[float, ScheduledFloat]
@@ -2083,10 +2182,41 @@ def _test_softmax():
+
+
+def _test_piecewise_linear():
+    p = PiecewiseLinear((0, 10.0))
+    for x in [-100, 0, 100]:
+        assert p(x) == 10.0
+    p = PiecewiseLinear((0, 10.0), (1, 0.0))
+    for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
+        print("x, y = ", x, y)
+        assert p(x) == y, (x, p(x), y)
+
+    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
+    x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
+    pq = p.max(q)
+    for x in x_vals:
+        y1 = max(p(x), q(x))
+        y2 = pq(x)
+        assert abs(y1 - y2) < 0.001
+    pq = p.min(q)
+    for x in x_vals:
+        y1 = min(p(x), q(x))
+        y2 = pq(x)
+        assert abs(y1 - y2) < 0.001
+    pq = p + q
+    for x in x_vals:
+        y1 = p(x) + q(x)
+        y2 = pq(x)
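+        # The pointwise sum of two PiecewiseLinear objects should match
+        # evaluating each function separately and adding the results.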
+        assert abs(y1 - y2) < 0.001
+
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_piecewise_linear()
     _test_softmax()
     _test_whiten()
     _test_max_eig()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 5396b1895..58b2e6cb2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -207,6 +207,15 @@ class Zipformer(EncoderInterface):
                 attention_share_layers=attention_share_layers[i],
             )
 
+            # Modify the layerdrop schedule with an extra term that takes longer
+            # to decay to zero for the less-downsampled layers; this encourages
+            # the more heavily downsampled layers to learn something.
+
+            extra_layerdrop = ScheduledFloat((0.0, 0.2), (20000.0 / downsampling_factor[i], 0.0))
+            for layer in encoder.layers:
+                # objects of type ScheduledFloat can be added together.
+                layer.layer_skip_rate = layer.layer_skip_rate + extra_layerdrop
+
             if downsampling_factor[i] != 1:
                 encoder = DownsampledZipformerEncoder(
                     encoder,
@@ -220,6 +229,9 @@
                 encoder.lr_scale = downsampling_factor[i] ** -0.25
             encoders.append(encoder)
+
+
+
         self.encoders = nn.ModuleList(encoders)
 
         # initializes self.skip_layers and self.skip_modules
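
For reviewers, here is a minimal sketch of how the new pieces compose (illustrative only, not part of the patch; it assumes the patched scaling.py is importable from the current directory, and all breakpoint values below are made up rather than taken from any recipe):

    from scaling import PiecewiseLinear, ScheduledFloat

    # A piecewise linear map: 10.0 for x <= 0, 0.0 for x >= 100,
    # linearly interpolated in between.
    p = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
    assert abs(p(50.0) - 5.0) < 1e-6

    # A scheduled value that decays from 0.2 at batch 0 to 0.0 at batch 4000;
    # `default` is returned in eval mode or while batch_count is unset.
    skip_rate = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    skip_rate.batch_count = 2000.0  # normally maintained by the training loop
    assert abs(float(skip_rate) - 0.1) < 1e-6

    # Schedules combine pointwise; this is what zipformer.py relies on when it
    # adds extra_layerdrop to each layer's existing layer_skip_rate.
    extra = ScheduledFloat((0.0, 0.2), (10000.0, 0.0), default=0.0)
    combined = skip_rate + extra
    combined.batch_count = 0.0
    assert abs(float(combined) - 0.4) < 1e-6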