mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 0b0732ae28 (parent 73099da6be): use ActivationDropoutAndLinearFunction and swoosh kernel functions
@@ -20,6 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging
import k2
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
@@ -1350,6 +1351,167 @@ class SwooshR(torch.nn.Module):
            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
        return SwooshRFunction.apply(x)


# simple version of SwooshL that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshLForward(x: Tensor):
    x_offset = x - 4.0
    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
    return log_sum - 0.08 * x - 0.035


# simple version of SwooshR that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshRForward(x: Tensor):
    x_offset = x - 1.0
    log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
    log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum)
    return log_sum - 0.08 * x - 0.313261687
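# Illustrative sketch, not part of this commit: SwooshRForward (just above)
# should agree with the logaddexp form that SwooshR uses when scripting; the
# torch.where guard only matters where exp() overflows to inf (e.g. low-precision
# inputs far above the offset), since log(1 + exp(t)) ~= t for large t.
# Assumes SwooshRForward from scaling.py is in scope.
def _check_swooshr_forward_matches_logaddexp():
    x = torch.linspace(-10.0, 10.0, steps=21)
    zero = torch.zeros_like(x)
    ref = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
    assert torch.allclose(SwooshRForward(x), ref, atol=1.0e-5)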
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx,
                x: Tensor,
                weight: Tensor,
                bias: Optional[Tensor],
                activation: str,
                dropout_p: float,
                dropout_shared_dim: Optional[int]):
        if dropout_p != 0.0:
            dropout_shape = list(x.shape)
            if dropout_shared_dim is not None:
                dropout_shape[dropout_shared_dim] = 1
            # else it won't be very memory efficient.
            dropout_mask = ((1.0 / (1.0 - dropout_p)) *
                            (torch.rand(*dropout_shape,
                                        device=x.device, dtype=x.dtype) > dropout_p))
        else:
            dropout_mask = None

        ctx.save_for_backward(x, weight, bias, dropout_mask)

        ctx.activation = activation

        forward_activation_dict = {
            'SwooshL': k2.swoosh_l_forward,
            'SwooshR': k2.swoosh_r_forward
        }
        # this will raise a KeyError if the activation is unrecognized; that is
        # an error, and we let it propagate to the user.
        activation_func = forward_activation_dict[activation]
        x = activation_func(x)
        if dropout_mask is not None:
            x = x * dropout_mask
        x = torch.nn.functional.linear(x, weight, bias)
        return x

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        saved = ctx.saved_tensors
        (x, weight, bias, dropout_mask) = saved

        forward_and_deriv_activation_dict = {
            'SwooshL': k2.swoosh_l_forward_and_deriv,
            'SwooshR': k2.swoosh_r_forward_and_deriv
        }
        # the following line will raise a KeyError if the activation is
        # unrecognized; that is an error, and we let it propagate to the user.
        func = forward_and_deriv_activation_dict[ctx.activation]

        y, func_deriv = func(x)
        if dropout_mask is not None:
            y = y * dropout_mask
        # now compute the derivative of the loss w.r.t. weight and bias.
        # y: (..., in_channels), ans_grad: (..., out_channels)
        (out_channels, in_channels) = weight.shape

        in_channels = y.shape[-1]
        g = ans_grad.reshape(-1, out_channels)
        weight_deriv = torch.matmul(g.t(),
                                    y.reshape(-1, in_channels))
        y_deriv = torch.matmul(ans_grad, weight)
        bias_deriv = None if bias is None else g.sum(dim=0)
        x_deriv = y_deriv * func_deriv
        if dropout_mask is not None:
            # order versus func_deriv does not matter
            x_deriv = x_deriv * dropout_mask

        return x_deriv, weight_deriv, bias_deriv, None, None, None
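# Illustrative sketch, not part of this commit: in backward() above, the k2
# kernels return the activation value together with its elementwise derivative,
# so the activation output never has to be saved for backprop; only the input
# (plus weight, bias and dropout mask) is kept. A pure-PyTorch stand-in for the
# SwooshL case (hypothetical helper, for reference only; the real
# k2.swoosh_l_forward_and_deriv is a fused kernel) could look like:
def _swoosh_l_forward_and_deriv_ref(x: Tensor) -> Tuple[Tensor, Tensor]:
    y = SwooshLForward(x)
    # d/dx [ log(1 + exp(x - 4)) - 0.08 * x - 0.035 ] = sigmoid(x - 4) - 0.08
    deriv = torch.sigmoid(x - 4.0) - 0.08
    return y, deriv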
class ActivationDropoutAndLinear(torch.nn.Module):
    """
    This merges an activation function followed by dropout and then a nn.Linear module;
    it does so in a memory-efficient way so that it only stores the input to the whole
    module.  If activation == SwooshL and dropout_shared_dim is not None, this will be
    equivalent to:
        nn.Sequential(SwooshL(),
                      Dropout3(dropout_p, shared_dim=dropout_shared_dim),
                      ScaledLinear(in_channels, out_channels, bias=bias,
                                   initial_scale=initial_scale))
    If dropout_shared_dim is None, the dropout would be equivalent to
    Dropout2(dropout_p).  Note: Dropout3 will be more memory efficient as the dropout
    mask is smaller.

    Args:
       in_channels: number of input channels, e.g. 256
       out_channels: number of output channels, e.g. 256
       bias: if true, have a bias
       activation: the activation function; 'SwooshL' and 'SwooshR' are supported.
       dropout_p: the dropout probability or schedule (happens after nonlinearity).
       dropout_shared_dim: the dimension, if any, across which the dropout mask is
            shared (e.g. the time dimension).  If None, this may be less memory
            efficient if there are modules before this one that cache the input
            for their backprop (e.g. Balancer or Whiten).
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 bias: bool = True,
                 activation: str = 'SwooshL',
                 dropout_p: FloatLike = 0.0,
                 dropout_shared_dim: Optional[int] = -1,
                 initial_scale: float = 1.0):
        super().__init__()
        # create a temporary module of nn.Linear that we'll steal the
        # weights and bias from
        l = ScaledLinear(in_channels, out_channels,
                         bias=bias,
                         initial_scale=initial_scale)

        self.weight = l.weight
        # register_parameter properly handles making it a parameter when l.bias
        # is None. I think there is some reason for doing it this way rather
        # than just setting it to None, but I don't know what it is; maybe
        # something to do with exporting the module.
        self.register_parameter('bias', l.bias)

        self.activation = activation
        self.dropout_p = dropout_p
        self.dropout_shared_dim = dropout_shared_dim

    def forward(self,
                x: Tensor):
        if torch.jit.is_scripting():
            if self.activation == 'SwooshL':
                x = SwooshLForward(x)
            elif self.activation == "SwooshR":
                x = SwooshRForward(x)
            else:
                assert False, self.activation
            return torch.nn.functional.linear(x,
                                              self.weight,
                                              self.bias)

        return ActivationDropoutAndLinearFunction.apply(
            x, self.weight, self.bias, self.activation,
            float(self.dropout_p), self.dropout_shared_dim)
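# Illustrative usage sketch, not part of this commit (shapes made up):
# ActivationDropoutAndLinear behaves like SwooshL -> Dropout3 -> ScaledLinear
# but only saves its own input (plus weight/bias and the dropout mask) for
# backprop.  With dropout_shared_dim=0 and input shape (time, batch, channels),
# the dropout mask has shape (1, batch, channels) and is broadcast over time.
def _example_activation_dropout_and_linear():
    m = ActivationDropoutAndLinear(256, 256, bias=True,
                                   activation='SwooshL',
                                   dropout_p=0.1,
                                   dropout_shared_dim=0,
                                   initial_scale=0.1)
    x = torch.randn(100, 8, 256, requires_grad=True)  # (time, batch, channels)
    y = m(x)                                          # (time, batch, 256)
    y.sum().backward()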
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    if num_channels <= x.shape[-1]:
        return x[..., :num_channels]
@@ -1360,8 +1522,6 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    return torch.cat((x, zeros), dim=-1)

def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
@@ -1391,8 +1551,6 @@ def _test_whiten():
        assert not torch.allclose(x.grad, y_grad)


def _test_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
@@ -1541,8 +1699,6 @@ def _test_caching_eval():
    assert torch.allclose(m[2].weight.grad, weight_grad1b)


def _test_piecewise_linear():
    p = PiecewiseLinear( (0, 10.0) )
    for x in [-100, 0, 100]:
@@ -1571,6 +1727,64 @@ def _test_piecewise_linear():
        assert abs(y1 - y2) < 0.001

def _test_activation_dropout_and_linear():
    in_channels = 20
    out_channels = 30

    for bias in [True, False]:
        # Note: with dropout_p != 0.0 the two modules would normally give
        # different answers; we reset the RNG seed before each forward pass
        # below so that the dropout masks match.
        for dropout_p in [0.0, 0.1]:
            for activation in ['SwooshL', 'SwooshR']:
                m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
                                   Dropout3(p=dropout_p, shared_dim=-1),
                                   ScaledLinear(in_channels, out_channels, bias=bias,
                                                initial_scale=0.5))
                m2 = ActivationDropoutAndLinear(in_channels, out_channels,
                                                bias=bias, initial_scale=0.5,
                                                activation=activation,
                                                dropout_p=dropout_p)
                with torch.no_grad():
                    m2.weight[:] = m1[2].weight
                    if bias:
                        m2.bias[:] = m1[2].bias
                # make sure forward gives the same result.
                x1 = torch.randn(10, in_channels)
                x1.requires_grad = True

                # TEMP.
                assert torch.allclose(SwooshRFunction.apply(x1),
                                      SwooshRForward(x1),
                                      atol=1.0e-03)

                x2 = x1.clone().detach()
                x2.requires_grad = True
                seed = 10
                torch.manual_seed(seed)
                y1 = m1(x1)
                y_grad = torch.randn_like(y1)
                y1.backward(gradient=y_grad)
                torch.manual_seed(seed)
                y2 = m2(x2)
                y2.backward(gradient=y_grad)

                print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}")
                print("y1 = ", y1)
                print("y2 = ", y2)
                assert torch.allclose(y1, y2, atol=0.02)
                assert torch.allclose(m1[2].weight.grad, m2.weight.grad,
                                      atol=1.0e-05)
                if bias:
                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad,
                                          atol=1.0e-05)
                print("x1.grad = ", x1.grad)
                print("x2.grad = ", x2.grad)

                def isclose(a, b):
                    # return true if cosine similarity is > 0.9.
                    return (a * b).sum() > 0.9 * ((a ** 2).sum() * (b ** 2).sum()).sqrt()

                # the SwooshL() implementation has a noisy gradient due to the
                # 1-byte storage of it.
                assert isclose(x1.grad, x2.grad)

if __name__ == "__main__":
@@ -1586,3 +1800,4 @@ if __name__ == "__main__":
    _test_double_swish_deriv()
    _test_swooshr_deriv()
    _test_swooshl_deriv()
+   _test_activation_dropout_and_linear()
@@ -32,6 +32,7 @@ from scaling import (
    SwooshL,
    SwooshR,
    ChunkCausalDepthwiseConv1d,
+   ActivationDropoutAndLinear,
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -435,6 +436,8 @@ class Zipformer2(EncoderInterface):
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior.
        assert self.output_downsampling_factor == 2
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            lengths = (lengths + 1) // 2

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -1460,7 +1463,6 @@ class SelfAttention(nn.Module):
        return x


class FeedforwardModule(nn.Module):
    """Feedforward module in Zipformer2 model.
    """
@@ -1477,10 +1479,12 @@ class FeedforwardModule(nn.Module):
                                        max_positive=1.0,
                                        min_abs=0.75,
                                        max_abs=5.0)
        self.activation = SwooshL()

        # shared_dim=0 means we share the dropout mask along the time axis
        self.dropout = Dropout3(dropout, shared_dim=0)
-       self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
+       self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim,
+                                                  activation='SwooshL',
+                                                  dropout_p=dropout,
+                                                  dropout_shared_dim=0, bias=True,
+                                                  initial_scale=0.1)

        self.out_whiten = Whiten(num_groups=1,
@@ -1492,8 +1496,7 @@ class FeedforwardModule(nn.Module):
                x: Tensor):
        x = self.in_proj(x)
        x = self.hidden_balancer(x)
-       x = self.activation(x)
-       x = self.dropout(x)
+       # out_proj contains SwooshL activation, then dropout, then linear.
        x = self.out_proj(x)
        x = self.out_whiten(x)
        return x
@@ -1670,7 +1673,6 @@ class ConvolutionModule(nn.Module):
            kernel_size=kernel_size,
            padding=kernel_size // 2)

        self.balancer2 = Balancer(
            bottleneck_dim, channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
@@ -1679,19 +1681,16 @@ class ConvolutionModule(nn.Module):
            max_abs=10.0,
        )

-       self.activation3 = SwooshR()

        self.whiten = Whiten(num_groups=1,
                             whitening_limit=_whitening_schedule(7.5),
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

-       self.out_proj = ScaledLinear(
-           bottleneck_dim, channels,
-           initial_scale=0.05,
+       self.out_proj = ActivationDropoutAndLinear(
+           bottleneck_dim, channels, activation='SwooshR',
+           dropout_p=0.0, initial_scale=0.05,
        )

    def forward(self,
                x: Tensor,
                src_key_padding_mask: Optional[Tensor] = None,
@@ -1724,7 +1723,7 @@ class ConvolutionModule(nn.Module):
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        if src_key_padding_mask is not None:
-           x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+           x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        if chunk_size >= 0:
            assert self.causal, "Must initialize model with causal=True if you use chunk_size"
@@ -1735,7 +1734,6 @@ class ConvolutionModule(nn.Module):
        x = self.balancer2(x)
        x = x.permute(2, 0, 1)  # (time, batch, channels)

-       x = self.activation3(x)
        x = self.whiten(x)  # (time, batch, channels)
        x = self.out_proj(x)  # (time, batch, channels)
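# Illustrative sketch, not part of this commit: in ConvolutionModule the
# separate SwooshR activation (activation3) has been folded into out_proj,
# which now applies SwooshR and the linear projection in one memory-efficient
# op (dropout_p=0.0 means no dropout is applied).  The dims below are made up;
# the real bottleneck_dim/channels come from the Zipformer2 config.
def _example_conv_module_out_proj():
    out_proj = ActivationDropoutAndLinear(192, 384, activation='SwooshR',
                                          dropout_p=0.0, initial_scale=0.05)
    h = torch.randn(100, 8, 192)   # (time, batch, bottleneck_dim)
    o = out_proj(h)                # SwooshR, then linear; only h is saved for backprop
    return o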