use ActivationDropoutAndLinearFunction and swoosh kernel functions

yaozengwei 2023-04-12 19:11:26 +08:00
parent 73099da6be
commit 0b0732ae28
2 changed files with 236 additions and 23 deletions


@@ -20,6 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging
import k2
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
@@ -1350,6 +1351,167 @@ class SwooshR(torch.nn.Module):
            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
        return SwooshRFunction.apply(x)
# simple version of SwooshL that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshLForward(x: Tensor):
    x_offset = x - 4.0
    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
    return log_sum - 0.08 * x - 0.035


# simple version of SwooshR that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshRForward(x: Tensor):
    x_offset = x - 1.0
    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
    return log_sum - 0.08 * x - 0.313261687
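A quick numerical check of why the torch.where guard above is needed (a minimal sketch using only torch, in float32): for large x, exp(x - 4.0) overflows to inf, so log(1 + exp(x - 4.0)) is inf, and the fallback to x_offset relies on log(1 + e^t) ≈ t for large t.

    import torch

    x = torch.tensor([0.0, 4.0, 100.0])     # float32 by default
    x_offset = x - 4.0
    log_sum = (1.0 + x_offset.exp()).log()  # approx [0.0181, 0.6931, inf]
    safe = torch.where(log_sum == float('inf'), x_offset, log_sum)
    print(safe)                             # approx [0.0181, 0.6931, 96.0]
    print(safe - 0.08 * x - 0.035)          # SwooshL values, finite for all inputs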
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx,
                x: Tensor,
                weight: Tensor,
                bias: Optional[Tensor],
                activation: str,
                dropout_p: float,
                dropout_shared_dim: Optional[int]):
        if dropout_p != 0.0:
            dropout_shape = list(x.shape)
            if dropout_shared_dim is not None:
                dropout_shape[dropout_shared_dim] = 1
            # else it won't be very memory efficient.
            dropout_mask = ((1.0 / (1.0 - dropout_p)) *
                            (torch.rand(*dropout_shape,
                                        device=x.device, dtype=x.dtype) > dropout_p))
        else:
            dropout_mask = None

        ctx.save_for_backward(x, weight, bias, dropout_mask)

        ctx.activation = activation

        forward_activation_dict = {
            'SwooshL': k2.swoosh_l_forward,
            'SwooshR': k2.swoosh_r_forward
        }
        # an unrecognized activation name raises a KeyError here; that is a
        # user error, and we let it propagate.
        activation_func = forward_activation_dict[activation]

        x = activation_func(x)
        if dropout_mask is not None:
            x = x * dropout_mask
        x = torch.nn.functional.linear(x, weight, bias)
        return x

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        saved = ctx.saved_tensors
        (x, weight, bias, dropout_mask) = saved

        forward_and_deriv_activation_dict = {
            'SwooshL': k2.swoosh_l_forward_and_deriv,
            'SwooshR': k2.swoosh_r_forward_and_deriv
        }
        # the following line raises a KeyError if the activation is unrecognized;
        # that is a user error, and we let it propagate.
        func = forward_and_deriv_activation_dict[ctx.activation]

        y, func_deriv = func(x)
        if dropout_mask is not None:
            y = y * dropout_mask
        # now compute the gradients w.r.t. the weight and bias.
        # y: (..., in_channels), ans_grad: (..., out_channels)
        (out_channels, in_channels) = weight.shape

        in_channels = y.shape[-1]
        g = ans_grad.reshape(-1, out_channels)
        weight_deriv = torch.matmul(g.t(),
                                    y.reshape(-1, in_channels))
        y_deriv = torch.matmul(ans_grad, weight)
        bias_deriv = None if bias is None else g.sum(dim=0)
        x_deriv = y_deriv * func_deriv
        if dropout_mask is not None:
            # order versus func_deriv does not matter
            x_deriv = x_deriv * dropout_mask

        return x_deriv, weight_deriv, bias_deriv, None, None, None
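The point of this custom Function is memory: ctx.save_for_backward() keeps only the input x (plus the weight, bias and the small shared dropout mask), and backward() recomputes the activation with k2's fused forward-and-derivative kernel instead of storing the activation output. A minimal sketch of calling it directly, with hypothetical shapes and assuming a k2 build that provides the swoosh kernels named above:

    import torch

    x = torch.randn(50, 4, 256, requires_grad=True)
    weight = torch.randn(192, 256, requires_grad=True)
    bias = torch.zeros(192, requires_grad=True)

    y = ActivationDropoutAndLinearFunction.apply(
        x, weight, bias, 'SwooshL', 0.1, 0)  # activation, dropout_p, dropout_shared_dim
    y.sum().backward()   # SwooshL(x) is recomputed here from the saved input
    print(x.grad.shape, weight.grad.shape, bias.grad.shape)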
class ActivationDropoutAndLinear(torch.nn.Module):
    """
    This merges an activation function followed by dropout and then a nn.Linear module;
    it does so in a memory efficient way so that it only stores the input to the whole
    module.  If activation == SwooshL and dropout_shared_dim != None, this will be
    equivalent to:
        nn.Sequential(SwooshL(),
                      Dropout3(dropout_p, shared_dim=dropout_shared_dim),
                      ScaledLinear(in_channels, out_channels, bias=bias,
                                   initial_scale=initial_scale))
    If dropout_shared_dim is None, the dropout would be equivalent to
    Dropout2(dropout_p).  Note: Dropout3 will be more memory efficient as the dropout
    mask is smaller.

    Args:
        in_channels: number of input channels, e.g. 256
        out_channels: number of output channels, e.g. 256
        bias: if true, include a bias term in the linear projection
        activation: the activation function; currently 'SwooshL' and 'SwooshR' are supported.
        dropout_p: the dropout probability or schedule (happens after nonlinearity).
        dropout_shared_dim: the dimension, if any, across which the dropout mask is
             shared (e.g. the time dimension).  If None, this may be less memory
             efficient if there are modules before this one that cache the input
             for their backprop (e.g. Balancer or Whiten).
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 bias: bool = True,
                 activation: str = 'SwooshL',
                 dropout_p: FloatLike = 0.0,
                 dropout_shared_dim: Optional[int] = -1,
                 initial_scale: float = 1.0):
        super().__init__()
        # create a temporary module of nn.Linear that we'll steal the
        # weights and bias from
        l = ScaledLinear(in_channels, out_channels,
                         bias=bias,
                         initial_scale=initial_scale)

        self.weight = l.weight
        # register_parameter properly handles making it a parameter when l.bias
        # is None. I think there is some reason for doing it this way rather
        # than just setting it to None but I don't know what it is, maybe
        # something to do with exporting the module..
        self.register_parameter('bias', l.bias)

        self.activation = activation
        self.dropout_p = dropout_p
        self.dropout_shared_dim = dropout_shared_dim

    def forward(self,
                x: Tensor):
        if torch.jit.is_scripting():
            if self.activation == 'SwooshL':
                x = SwooshLForward(x)
            elif self.activation == "SwooshR":
                x = SwooshRForward(x)
            else:
                assert False, self.activation
            return torch.nn.functional.linear(x,
                                              self.weight,
                                              self.bias)

        return ActivationDropoutAndLinearFunction.apply(
            x, self.weight, self.bias, self.activation,
            float(self.dropout_p), self.dropout_shared_dim)
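A brief usage sketch of the module defined above, with illustrative channel sizes; it assumes the k2 swoosh kernels used by the fused function are available:

    import torch

    m = ActivationDropoutAndLinear(in_channels=256, out_channels=256,
                                   bias=True, activation='SwooshL',
                                   dropout_p=0.1, dropout_shared_dim=0,
                                   initial_scale=0.1)
    x = torch.randn(100, 8, 256)   # (time, batch, channels); dropout mask shared over time
    y = m(x)                       # SwooshL -> dropout -> linear in one fused autograd op
    print(y.shape)                 # torch.Size([100, 8, 256])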
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    if num_channels <= x.shape[-1]:
        return x[..., :num_channels]
@@ -1360,8 +1522,6 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    return torch.cat((x, zeros), dim=-1)
def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
@@ -1391,8 +1551,6 @@ def _test_whiten():
    assert not torch.allclose(x.grad, y_grad)


def _test_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
@@ -1541,8 +1699,6 @@ def _test_caching_eval():
    assert torch.allclose(m[2].weight.grad, weight_grad1b)


def _test_piecewise_linear():
    p = PiecewiseLinear( (0, 10.0) )
    for x in [-100, 0, 100]:
@@ -1571,6 +1727,64 @@ def _test_piecewise_linear():
        assert abs(y1 - y2) < 0.001
def _test_activation_dropout_and_linear():
    in_channels = 20
    out_channels = 30

    for bias in [True, False]:
        # actually we don't test for dropout_p != 0.0 because the forward functions would
        # give different answers: the two implementations draw their dropout masks
        # differently, so they would not match even with the same random seed.
        for dropout_p in [0.0, 0.1]:
            for activation in ['SwooshL', 'SwooshR']:
                m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
                                   Dropout3(p=dropout_p, shared_dim=-1),
                                   ScaledLinear(in_channels, out_channels, bias=bias,
                                                initial_scale=0.5))
                m2 = ActivationDropoutAndLinear(in_channels, out_channels,
                                                bias=bias, initial_scale=0.5,
                                                activation=activation,
                                                dropout_p=dropout_p)
                with torch.no_grad():
                    m2.weight[:] = m1[2].weight
                    if bias:
                        m2.bias[:] = m1[2].bias
                # make sure forward gives same result.
                x1 = torch.randn(10, in_channels)
                x1.requires_grad = True

                # TEMP.
                assert torch.allclose(SwooshRFunction.apply(x1),
                                      SwooshRForward(x1),
                                      atol=1.0e-03)

                x2 = x1.clone().detach()
                x2.requires_grad = True

                seed = 10
                torch.manual_seed(seed)
                y1 = m1(x1)
                y_grad = torch.randn_like(y1)
                y1.backward(gradient=y_grad)

                torch.manual_seed(seed)
                y2 = m2(x2)
                y2.backward(gradient=y_grad)

                print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}")
                print("y1 = ", y1)
                print("y2 = ", y2)

                assert torch.allclose(y1, y2, atol=0.02)
                assert torch.allclose(m1[2].weight.grad, m2.weight.grad,
                                      atol=1.0e-05)
                if bias:
                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad,
                                          atol=1.0e-05)

                print("x1.grad = ", x1.grad)
                print("x2.grad = ", x2.grad)

                def isclose(a, b):
                    # return true if cosine similarity is > 0.9.
                    return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt()

                # the SwooshL() implementation has a noisy gradient due to 1-byte
                # storage of it.
                assert isclose(x1.grad, x2.grad)
if __name__ == "__main__":
@@ -1586,3 +1800,4 @@ if __name__ == "__main__":
    _test_double_swish_deriv()
    _test_swooshr_deriv()
    _test_swooshl_deriv()
    _test_activation_dropout_and_linear()


@@ -32,6 +32,7 @@ from scaling import (
    SwooshL,
    SwooshR,
    ChunkCausalDepthwiseConv1d,
    ActivationDropoutAndLinear,
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -435,7 +436,9 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        lengths = (lengths + 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            lengths = (lengths + 1) // 2

         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -1460,7 +1463,6 @@
        return x


class FeedforwardModule(nn.Module):
    """Feedforward module in Zipformer2 model.
    """
@@ -1477,11 +1479,13 @@
                                         max_positive=1.0,
                                         min_abs=0.75,
                                         max_abs=5.0)
-        self.activation = SwooshL()

         # shared_dim=0 means we share the dropout mask along the time axis
-        self.dropout = Dropout3(dropout, shared_dim=0)
-        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
-                                     initial_scale=0.1)
+        self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim,
+                                                   activation='SwooshL',
+                                                   dropout_p=dropout,
+                                                   dropout_shared_dim=0, bias=True,
+                                                   initial_scale=0.1)

         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
@@ -1492,8 +1496,7 @@
                 x: Tensor):
         x = self.in_proj(x)
         x = self.hidden_balancer(x)
-        x = self.activation(x)
-        x = self.dropout(x)
+        # out_proj contains SwooshL activation, then dropout, then linear.
         x = self.out_proj(x)
         x = self.out_whiten(x)
         return x
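For reference, a sketch of what the replacement above folds together in FeedforwardModule: the previous SwooshL + Dropout3 + ScaledLinear stack versus the new fused module (placeholder dimensions; all classes as imported from scaling):

    import torch.nn as nn
    from scaling import (SwooshL, Dropout3, ScaledLinear,
                         ActivationDropoutAndLinear)

    feedforward_dim, embed_dim, dropout = 1536, 384, 0.1   # placeholder values

    # previous path: activation, dropout and projection as separate modules
    old_out_proj = nn.Sequential(SwooshL(),
                                 Dropout3(dropout, shared_dim=0),
                                 ScaledLinear(feedforward_dim, embed_dim,
                                              initial_scale=0.1))

    # new path: same computation, but only the module input is kept for backprop
    new_out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim,
                                              activation='SwooshL', dropout_p=dropout,
                                              dropout_shared_dim=0, bias=True,
                                              initial_scale=0.1)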
@@ -1670,7 +1673,6 @@
            kernel_size=kernel_size,
            padding=kernel_size // 2)

        self.balancer2 = Balancer(
            bottleneck_dim, channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
@@ -1679,19 +1681,16 @@
             max_abs=10.0,
         )

-        self.activation3 = SwooshR()

         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)

-        self.out_proj = ScaledLinear(
-            bottleneck_dim, channels,
-            initial_scale=0.05,
-        )
+        self.out_proj = ActivationDropoutAndLinear(
+            bottleneck_dim, channels, activation='SwooshR',
+            dropout_p=0.0, initial_scale=0.05,
+        )
    def forward(self,
                x: Tensor,
                src_key_padding_mask: Optional[Tensor] = None,
@@ -1724,7 +1723,7 @@
         x = x.permute(1, 2, 0)  # (#batch, channels, time).

         if src_key_padding_mask is not None:
-            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

         if chunk_size >= 0:
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
@@ -1735,7 +1734,6 @@
         x = self.balancer2(x)
         x = x.permute(2, 0, 1)  # (time, batch, channels)

-        x = self.activation3(x)
         x = self.whiten(x)  # (time, batch, channels)
         x = self.out_proj(x)  # (time, batch, channels)