diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index dc0263d1f..6e6edf5f4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -20,6 +20,7 @@ from itertools import repeat
 from typing import Optional, Tuple, Union
 from functools import reduce
 import logging
+import k2
 from torch.cuda.amp import custom_fwd, custom_bwd
 import random
 import torch
@@ -1350,6 +1351,167 @@ class SwooshR(torch.nn.Module):
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         return SwooshRFunction.apply(x)
 
+
+# simple version of SwooshL that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshLForward(x: Tensor):
+    x_offset = x - 4.0
+    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
+    return log_sum - 0.08 * x - 0.035
+
+
+# simple version of SwooshR that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor):
+    x_offset = x - 1.0
+    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
+    return log_sum - 0.08 * x - 0.313261687
+
+
+class ActivationDropoutAndLinearFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx,
+                x: Tensor,
+                weight: Tensor,
+                bias: Optional[Tensor],
+                activation: str,
+                dropout_p: float,
+                dropout_shared_dim: Optional[int]):
+        if dropout_p != 0.0:
+            dropout_shape = list(x.shape)
+            if dropout_shared_dim is not None:
+                dropout_shape[dropout_shared_dim] = 1
+                # else it won't be very memory efficient.
+            dropout_mask = ((1.0 / (1.0 - dropout_p)) *
+                            (torch.rand(*dropout_shape,
+                                        device=x.device, dtype=x.dtype) > dropout_p))
+        else:
+            dropout_mask = None
+
+        ctx.save_for_backward(x, weight, bias, dropout_mask)
+
+        ctx.activation = activation
+
+        forward_activation_dict = {
+            'SwooshL': k2.swoosh_l_forward,
+            'SwooshR': k2.swoosh_r_forward
+        }
+        # this raises a KeyError if the activation is unrecognized; that is a
+        # user error, so we let it propagate.
+        activation_func = forward_activation_dict[activation]
+        x = activation_func(x)
+        if dropout_mask is not None:
+            x = x * dropout_mask
+        x = torch.nn.functional.linear(x, weight, bias)
+        return x
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor):
+        saved = ctx.saved_tensors
+        (x, weight, bias, dropout_mask) = saved
+
+        forward_and_deriv_activation_dict = {
+            'SwooshL': k2.swoosh_l_forward_and_deriv,
+            'SwooshR': k2.swoosh_r_forward_and_deriv
+        }
+        # the following line raises a KeyError if the activation is unrecognized;
+        # that is a user error, so we let it propagate.
+        func = forward_and_deriv_activation_dict[ctx.activation]
+
+        y, func_deriv = func(x)
+        if dropout_mask is not None:
+            y = y * dropout_mask
+        # now compute the derivatives w.r.t. the weight and bias.
+        # y: (..., in_channels), ans_grad: (..., out_channels)
+        (out_channels, in_channels) = weight.shape
+
+        in_channels = y.shape[-1]
+        g = ans_grad.reshape(-1, out_channels)
+        weight_deriv = torch.matmul(g.t(),
+                                    y.reshape(-1, in_channels))
+        y_deriv = torch.matmul(ans_grad, weight)
+        bias_deriv = None if bias is None else g.sum(dim=0)
+        x_deriv = y_deriv * func_deriv
+        if dropout_mask is not None:
+            # order versus func_deriv does not matter
+            x_deriv = x_deriv * dropout_mask
+
+        return x_deriv, weight_deriv, bias_deriv, None, None, None
+
+
+class ActivationDropoutAndLinear(torch.nn.Module):
+    """
+    This merges an activation function followed by dropout and then a nn.Linear module;
+    it does so in a memory efficient way so that it only stores the input to the whole
+    module.  If activation == SwooshL and dropout_shared_dim != None, this will be
+    equivalent to:
+      nn.Sequential(SwooshL(),
+                    Dropout3(dropout_p, shared_dim=dropout_shared_dim),
+                    ScaledLinear(in_channels, out_channels, bias=bias,
+                                 initial_scale=initial_scale))
+    If dropout_shared_dim is None, the dropout would be equivalent to
+    Dropout2(dropout_p).  Note: Dropout3 is more memory efficient as the dropout
+    mask is smaller.
+
+    Args:
+       in_channels: number of input channels, e.g. 256
+       out_channels: number of output channels, e.g. 256
+       bias: if True, include a bias term in the linear projection
+       activation: the activation function; 'SwooshL' and 'SwooshR' are supported.
+       dropout_p: the dropout probability or schedule (applied after the nonlinearity).
+       dropout_shared_dim: the dimension, if any, across which the dropout mask is
+            shared (e.g. the time dimension).  If None, this may be less memory
+            efficient if there are modules before this one that cache the input
+            for their backprop (e.g. Balancer or Whiten).
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 bias: bool = True,
+                 activation: str = 'SwooshL',
+                 dropout_p: FloatLike = 0.0,
+                 dropout_shared_dim: Optional[int] = -1,
+                 initial_scale: float = 1.0):
+        super().__init__()
+        # create a temporary module of nn.Linear that we'll steal the
+        # weight and bias from
+        l = ScaledLinear(in_channels, out_channels,
+                         bias=bias,
+                         initial_scale=initial_scale)
+
+        self.weight = l.weight
+        # register_parameter properly handles making it a parameter when l.bias
+        # is None.  There seems to be a reason for doing it this way rather
+        # than just setting self.bias = None, possibly something to do with
+        # exporting the module.
+        self.register_parameter('bias', l.bias)
+
+        self.activation = activation
+        self.dropout_p = dropout_p
+        self.dropout_shared_dim = dropout_shared_dim
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting():
+            if self.activation == 'SwooshL':
+                x = SwooshLForward(x)
+            elif self.activation == "SwooshR":
+                x = SwooshRForward(x)
+            else:
+                assert False, self.activation
+            return torch.nn.functional.linear(x,
+                                              self.weight,
+                                              self.bias)
+
+        return ActivationDropoutAndLinearFunction.apply(
+            x, self.weight, self.bias, self.activation,
+            float(self.dropout_p), self.dropout_shared_dim)
+
+
 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
     if num_channels <= x.shape[-1]:
         return x[..., :num_channels]
@@ -1360,8 +1522,6 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
     return torch.cat((x, zeros), dim=-1)
 
 
-
-
 def _test_whiten():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"_test_whiten(): proportion = {proportion}")
@@ -1391,8 +1551,6 @@ def _test_whiten():
     assert not torch.allclose(x.grad, y_grad)
 
 
-
-
 def _test_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
@@ -1541,8 +1699,6 @@ def _test_caching_eval():
     assert torch.allclose(m[2].weight.grad, weight_grad1b)
 
 
-
-
 def _test_piecewise_linear():
     p = PiecewiseLinear( (0, 10.0) )
     for x in [-100, 0, 100]:
@@ -1571,6 +1727,64 @@ def _test_piecewise_linear():
     assert abs(y1 - y2) < 0.001
 
 
+def _test_activation_dropout_and_linear():
+    in_channels = 20
+    out_channels = 30
+
+    for bias in [True, False]:
+        # actually we don't test for dropout_p != 0.0 here because the two forward
+        # functions will give different answers in that case.
+        for dropout_p in [0.0, 0.1]:
+            for activation in ['SwooshL', 'SwooshR']:
+                m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
+                                   Dropout3(p=dropout_p, shared_dim=-1),
+                                   ScaledLinear(in_channels, out_channels, bias=bias,
+                                                initial_scale=0.5))
+                m2 = ActivationDropoutAndLinear(in_channels, out_channels,
+                                                bias=bias, initial_scale=0.5,
+                                                activation=activation,
+                                                dropout_p=dropout_p)
+                with torch.no_grad():
+                    m2.weight[:] = m1[2].weight
+                    if bias:
+                        m2.bias[:] = m1[2].bias
+                # make sure forward gives same result.
+                x1 = torch.randn(10, in_channels)
+                x1.requires_grad = True
+
+                # TEMP.
+                assert torch.allclose(SwooshRFunction.apply(x1),
+                                      SwooshRForward(x1),
+                                      atol=1.0e-03)
+
+                x2 = x1.clone().detach()
+                x2.requires_grad = True
+                seed = 10
+                torch.manual_seed(seed)
+                y1 = m1(x1)
+                y_grad = torch.randn_like(y1)
+                y1.backward(gradient=y_grad)
+                torch.manual_seed(seed)
+                y2 = m2(x2)
+                y2.backward(gradient=y_grad)
+
+                print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}")
+                print("y1 = ", y1)
+                print("y2 = ", y2)
+                assert torch.allclose(y1, y2, atol=0.02)
+                assert torch.allclose(m1[2].weight.grad, m2.weight.grad,
+                                      atol=1.0e-05)
+                if bias:
+                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad,
+                                          atol=1.0e-05)
+                print("x1.grad = ", x1.grad)
+                print("x2.grad = ", x2.grad)
+                def isclose(a, b):
+                    # return true if cosine similarity is > 0.9.
+                    return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt()
+                # the SwooshL() implementation has a noisy gradient due to its
+                # 1-byte storage.
+                assert isclose(x1.grad, x2.grad)
 
 
 if __name__ == "__main__":
@@ -1586,3 +1800,4 @@ if __name__ == "__main__":
     _test_double_swish_deriv()
     _test_swooshr_deriv()
     _test_swooshl_deriv()
+    _test_activation_dropout_and_linear()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f270a48fc..06dabe420 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -32,6 +32,7 @@ from scaling import (
     SwooshL,
     SwooshR,
     ChunkCausalDepthwiseConv1d,
+    ActivationDropoutAndLinear,
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -435,7 +436,9 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        lengths = (lengths + 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            lengths = (lengths + 1) // 2
 
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
@@ -1460,7 +1463,6 @@ class SelfAttention(nn.Module):
 
         return x
 
-
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer2 model.
     """
@@ -1477,11 +1479,13 @@ class FeedforwardModule(nn.Module):
                                        max_positive=1.0,
                                        min_abs=0.75,
                                        max_abs=5.0)
-        self.activation = SwooshL()
+
         # shared_dim=0 means we share the dropout mask along the time axis
-        self.dropout = Dropout3(dropout, shared_dim=0)
-        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
-                                     initial_scale=0.1)
+        self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim,
+                                                   activation='SwooshL',
+                                                   dropout_p=dropout,
+                                                   dropout_shared_dim=0, bias=True,
+                                                   initial_scale=0.1)
 
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
@@ -1492,8 +1496,7 @@ class FeedforwardModule(nn.Module):
                 x: Tensor):
         x = self.in_proj(x)
         x = self.hidden_balancer(x)
-        x = self.activation(x)
-        x = self.dropout(x)
+        # out_proj contains SwooshL activation, then dropout, then linear.
         x = self.out_proj(x)
         x = self.out_whiten(x)
         return x
@@ -1670,7 +1673,6 @@ class ConvolutionModule(nn.Module):
             kernel_size=kernel_size,
             padding=kernel_size // 2)
 
-
         self.balancer2 = Balancer(
             bottleneck_dim, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
@@ -1679,19 +1681,16 @@ class ConvolutionModule(nn.Module):
             max_abs=10.0,
         )
 
-        self.activation3 = SwooshR()
-
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
-        self.out_proj = ScaledLinear(
-            bottleneck_dim, channels,
-            initial_scale=0.05,
+        self.out_proj = ActivationDropoutAndLinear(
+            bottleneck_dim, channels, activation='SwooshR',
+            dropout_p=0.0, initial_scale=0.05,
         )
 
-
     def forward(self,
                 x: Tensor,
                 src_key_padding_mask: Optional[Tensor] = None,
@@ -1724,7 +1723,7 @@ class ConvolutionModule(nn.Module):
         x = x.permute(1, 2, 0)  # (#batch, channels, time).
 
         if src_key_padding_mask is not None:
-            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
 
         if chunk_size >= 0:
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
@@ -1735,7 +1734,6 @@ class ConvolutionModule(nn.Module):
         x = self.balancer2(x)
         x = x.permute(2, 0, 1)  # (time, batch, channels)
 
-        x = self.activation3(x)
         x = self.whiten(x)  # (time, batch, channels)
 
         x = self.out_proj(x)  # (time, batch, channels)
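Usage-wise, `ActivationDropoutAndLinear` is meant as a drop-in replacement for the unfused activation/dropout/linear stack it removes from `FeedforwardModule` and `ConvolutionModule` above. The following is a minimal sketch of that equivalence, not part of the patch: it assumes the recipe directory (egs/librispeech/ASR/pruned_transducer_stateless7) is on `PYTHONPATH` and that k2 is installed (the fused forward calls `k2.swoosh_l_forward`); the sizes, shapes and tolerance are illustrative rather than taken from the diff.

```python
import torch

# assumption: `scaling` is importable from the recipe directory and k2 is installed
from scaling import ActivationDropoutAndLinear, ScaledLinear, SwooshL

ff_dim, embed_dim = 1024, 256   # illustrative sizes, not taken from the patch

# Unfused reference: SwooshL activation followed by a scaled linear projection.
# dropout is kept at 0.0 so the comparison is deterministic.
ref = torch.nn.Sequential(
    SwooshL(),
    ScaledLinear(ff_dim, embed_dim, bias=True, initial_scale=0.1),
)

# Fused version: only the input tensor is cached for backprop.
fused = ActivationDropoutAndLinear(ff_dim, embed_dim, bias=True,
                                   activation='SwooshL', dropout_p=0.0,
                                   dropout_shared_dim=0, initial_scale=0.1)

# Share the parameters so the two paths compute the same function.
with torch.no_grad():
    fused.weight[:] = ref[1].weight
    fused.bias[:] = ref[1].bias

x = torch.randn(50, 8, ff_dim)   # (time, batch, channels)
y_ref = ref(x)
y_fused = fused(x)
# loose tolerance, as in the patch's own test, since the two SwooshL
# implementations are not bit-identical
print(torch.allclose(y_ref, y_fused, atol=0.02))
```

With a non-zero `dropout_p` the two paths are not expected to agree element-wise (the new `_test_activation_dropout_and_linear` comment makes the same point), and note that, as written in this patch, the fused forward applies dropout whenever `dropout_p` is non-zero, without consulting the module's train/eval mode.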