diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 93d6d631b..0fcd73878 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -29,6 +29,76 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
+
+
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function. You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop. It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index. For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, or not in training mode, or in
+    torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x,y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -317,6 +387,110 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
+class LinearWithAuxLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+                aux_grad_scale: float) -> Tensor:
+        """
+        Returns matmul(x, weight.t()).
+        In the backward pass it will include an auxiliary loss based on predicting x from
+        matmul(y, weight).
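+
+        Args (summarising how backward() below uses them):
+          x: the input, of shape (*, in_channels)
+          weight: the weight of the underlying linear layer, of shape (out_channels, in_channels)
+          alpha: a learnable scalar that scales the reconstruction z = alpha * matmul(y, weight)
+          aux_grad_scale: the norm of the auxiliary-loss gradient added to the weight gradient,
+            expressed relative to the norm of the regular weight gradient.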
+        """
+        ctx.save_for_backward(x, weight, alpha)
+        ctx.aux_grad_scale = aux_grad_scale
+        return torch.matmul(x, weight.t())
+
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
+        x, weight, alpha = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                weight.requires_grad = True
+                alpha.requires_grad = True
+                # recompute y as we need the gradient; this is easier to implement than
+                # saving y in the context.
+                y = torch.matmul(x, weight.t())
+                z = alpha * torch.matmul(y, weight)
+                diff = x - z
+                dims_to_mean = tuple(range(x.ndim-1))
+                mean = diff.mean(dim=dims_to_mean)
+                diff = diff - mean  # subtract mean.
+                # meansq is the loss function.
+                meansq = (diff ** 2).mean()
+                meansq.backward()
+                weight_aux_grad = weight.grad
+                alpha_grad = alpha.grad
+
+        x_grad = torch.matmul(ans_grad, weight)
+        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
+                                   x.reshape(-1, x.shape[-1]))
+
+        with torch.cuda.amp.autocast(enabled=False):
+            weight_grad_norm = weight_grad.to(torch.float32).norm()
+            aux_grad_norm = weight_aux_grad.norm()
+            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
+            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
+
+        return x_grad, weight_grad, alpha_grad, None
+
+
+
+class LinearWithAuxLoss(nn.Module):
+    """
+    A linear layer with an auxiliary loss, which you can put on a schedule,
+    that encourages the weight matrix to correspond to the largest-variance
+    directions of the input features.
+
+    Suppose the input is x, and this layer computes:
+         y = M x
+    (the bias is applied separately); then we define:
+         z = alpha * M^T y
+    where alpha is learnable; and the auxiliary loss will be:
+         aux_loss = normalize_mean(z - x)^2
+    (normalize_mean refers to subtracting the average value per channel,
+    over the minibatch).
+    In the backward pass we compute the derivative of the auxiliary loss
+    and add it to the weight grad, with a scale chosen such
+    that the extra grad's norm equals `aux_grad_scale` times the norm
+    of the existing grad.
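+
+    Example (the channel sizes are arbitrary, chosen only for illustration):
+        self.proj = LinearWithAuxLoss(256, 128,
+                                      aux_grad_scale=ScheduledFloat((0.0, 1.0), (8000.0, 0.0)),
+                                      prob=0.25)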
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 bias: bool = True,
+                 aux_grad_scale: Optional[FloatLike] = None,
+                 prob: FloatLike = 0.25,
+                 initial_scale: float = 1.0,
+                 ):
+        super().__init__()
+        if aux_grad_scale is None:
+            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
+                                            (2000.0, 0.01), (8000.0, 0.0))
+
+        self.aux_grad_scale = aux_grad_scale
+        self.prob = prob
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
+                                   * (in_channels ** -0.5) * initial_scale)
+        if bias:
+            self.bias = nn.Parameter(torch.randn(out_channels) *
+                                     0.01 * initial_scale)
+        else:
+            self.register_parameter('bias', None)
+        self.alpha = nn.Parameter(torch.tensor(1.0))
+
+
+    def forward(self,
+                x: Tensor):
+        aux_grad_scale = float(self.aux_grad_scale)
+        if (not self.training or torch.jit.is_scripting() or
+                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
+            return torch.matmul(x, self.weight.t()) + self.bias
+        else:
+            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                   aux_grad_scale) + self.bias
 
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
@@ -417,14 +591,14 @@ class ActivationBalancer(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int,
-        min_positive: float = 0.05,
-        max_positive: float = 0.95,
-        max_factor: float = 0.04,
-        sign_gain_factor: float = 0.01,
-        scale_gain_factor: float = 0.02,
-        min_abs: float = 0.2,
-        max_abs: float = 100.0,
-        min_prob: float = 0.1,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        max_factor: FloatLike = 0.04,
+        sign_gain_factor: FloatLike = 0.01,
+        scale_gain_factor: FloatLike = 0.02,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        min_prob: FloatLike = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +627,26 @@ class ActivationBalancer(torch.nn.Module):
 
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
-            if self.min_positive != 0.0 or self.max_positive != 1.0:
+            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
-                                                   self.min_positive, self.max_positive,
-                                                   gain_factor=self.sign_gain_factor / prob,
-                                                   max_factor=self.max_factor)
+                                                   float(self.min_positive),
+                                                   float(self.max_positive),
+                                                   gain_factor=float(self.sign_gain_factor) / prob,
+                                                   max_factor=float(self.max_factor))
             else:
                 sign_factor = None
 
             scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=self.min_abs,
-                                                 max_abs=self.max_abs,
-                                                 gain_factor=self.scale_gain_factor / prob,
-                                                 max_factor=self.max_factor)
+                                                 min_abs=float(self.min_abs),
+                                                 max_abs=float(self.max_abs),
+                                                 gain_factor=float(self.scale_gain_factor) / prob,
+                                                 max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
                 x, scale_factor, sign_factor, self.channel_dim,
             )
@@ -519,74 +694,6 @@ def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
 
 
-
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function. You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index. For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set or in training or mode or in
-    torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x,y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _whitening_metric(x: Tensor, num_groups: int):
     """
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a973dd74b..a99ae4f18 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -32,6 +32,7 @@ from scaling import (
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    LinearWithAuxLoss,
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
    penalize_abs_values_gt,
@@ -1288,9 +1289,8 @@ class AttentionSqueeze(nn.Module):
 
         self.in_proj = nn.Linear(embed_dim, embed_dim, bias=False)
 
-        self.to_bottleneck_proj = nn.Linear(embed_dim,
-                                            bottleneck_dim,
-                                            bias=False)
+        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
+                                                    bottleneck_dim)
 
         # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"