diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index b380fa145..93d6d631b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -518,6 +518,75 @@ def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
         return x
 
 
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, when we are not in training
+    mode, or when we are in torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x,y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 def _whitening_metric(x: Tensor,
                       num_groups: int):
     """
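Note: a minimal usage sketch of ScheduledFloat, not part of the patch.  Per the
docstring above, the training loop is expected to write batch_count (and name)
onto each instance it finds via [top_level module].modules(), and consumers read
the value back by casting to float.  The variable names and numbers here are
illustrative only.

    from scaling import ScheduledFloat   # as added by this patch

    dropout_schedule = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    dropout_schedule.name = 'encoder.dropout'      # hypothetical name
    for batch_idx in (0, 2000, 4000, 6000):
        dropout_schedule.batch_count = batch_idx   # normally set by the training loop
        p = float(dropout_schedule)                # -> 0.2, 0.1, 0.0, 0.0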
@@ -593,12 +662,11 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
 
 
-
 class Whiten(nn.Module):
     def __init__(
             self,
             num_groups: int,
-            whitening_limit: float,
+            whitening_limit: FloatLike,
             prob: Union[float, Tuple[float,float]],
             grad_scale: float):
         """
@@ -621,7 +689,7 @@ class Whiten(nn.Module):
         """
         super(Whiten, self).__init__()
         assert num_groups >= 1
-        assert whitening_limit >= 1
+        assert float(whitening_limit) >= 1
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
@@ -656,10 +724,11 @@ class Whiten(nn.Module):
         if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
             return _no_op(x)
         else:
+            whitening_limit = float(self.whitening_limit)
             if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
                 # we are above or below the threshold.
-                if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
+                if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
                     # there would be a change to the grad.
                     self.prob = self.max_prob
                 else:
@@ -667,7 +736,7 @@ class Whiten(nn.Module):
 
         return WhiteningPenaltyFunction.apply(x,
                                               self.num_groups,
-                                              self.whitening_limit,
+                                              whitening_limit,
                                               self.grad_scale,
                                               self.name)
 
@@ -1003,72 +1072,6 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)
 
 
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specifiy the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set or in training or mode or in
-     torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x,y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
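Note: since Whiten now applies float() to its limit on each forward() call (and
once in __init__ for the assertion), whitening_limit may be either a plain float
or a ScheduledFloat, which is what the FloatLike alias expresses; a schedule is
re-read as batch_count advances.  A sketch with illustrative values, not taken
from this patch:

    from scaling import Whiten, ScheduledFloat

    # fixed limit, as before
    w_fixed = Whiten(num_groups=1, whitening_limit=5.0,
                     prob=(0.025, 0.25), grad_scale=0.01)
    # scheduled limit: 4.0 at batch 0, rising linearly to 8.0 by batch 12000
    w_sched = Whiten(num_groups=1,
                     whitening_limit=ScheduledFloat((0.0, 4.0), (12000.0, 8.0),
                                                    default=4.0),
                     prob=(0.025, 0.25), grad_scale=0.01)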
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f1dc64cbf..f88b65d34 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -336,6 +336,11 @@ class Zipformer(EncoderInterface):
         return x, lengths
 
 
+def _whitening_schedule(x: float) -> ScheduledFloat:
+    return ScheduledFloat((0.0, x),
+                          (12000.0, 2.0 * x),
+                          default=x)
+
 class ZipformerEncoderLayer(nn.Module):
     """
     Args:
@@ -424,7 +429,7 @@ class ZipformerEncoderLayer(nn.Module):
                 max_abs=6.0,
             )
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=5.0,
+                             whitening_limit=_whitening_schedule(4.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1048,9 +1053,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=query_head_dim**-0.25)
 
-        # .. TODO: tune this limit? whitening_limit.
         self.whiten_keys = Whiten(num_groups=num_heads,
-                                  whitening_limit=2.0,
+                                  whitening_limit=_whitening_schedule(2.0),
                                   prob=(0.025, 0.25),
                                   grad_scale=0.025)
 
@@ -1227,7 +1231,7 @@ class SelfAttention(nn.Module):
                                     initial_scale=0.05)
 
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=15.0,
+                             whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1331,7 +1335,7 @@ class AttentionSqueeze(nn.Module):
                                   bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
 
@@ -1388,7 +1392,7 @@ class FeedforwardModule(nn.Module):
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
 
@@ -1433,7 +1437,7 @@ class NonlinAttentionModule(nn.Module):
                                     initial_scale=0.05)
 
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=15.0,
+                             whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1555,7 +1559,7 @@ class ConvolutionModule(nn.Module):
             )
 
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=15.0,
+                                 whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
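Note: _whitening_schedule(4.0) expands to
ScheduledFloat((0.0, 4.0), (12000.0, 8.0), default=4.0), so each whitening
limit now starts at its base value x and rises linearly to 2x over the first
12000 batches, staying at 2x thereafter; outside training mode, or under
torch.jit scripting, the constant default x applies.  For example, the
ZipformerEncoderLayer limit at batch 6000 would be

    # piecewise-linear interpolation, per ScheduledFloat.__float__:
    # 4.0 + (8.0 - 4.0) * (6000 - 0) / (12000 - 0) = 6.0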