# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import random
from typing import Optional, Tuple, Union
import k2
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
class BiasNorm(torch.nn.Module):
"""
    This is intended to be a simpler replacement for LayerNorm. The observation this is based on
    is that Transformer-type networks, especially with pre-normalization, sometimes seem to set one
of the channel dimensions to a large constant value (e.g. 50), which "defeats" the LayerNorm
because the output magnitude is then not strongly dependent on the other (useful) channels.
Presumably the weight and bias of the LayerNorm are required to allow it to do this. Instead,
we give the BiasNorm a trainable bias that it can use when computing the scale for
normalization. We also give it a scalar trainable scale on the output.
"""
def __init__(self, num_channels: int, device: torch.device) -> None:
"""
BiasNorm initialization.
Parameters
----------
num_channels : int
The number of input channels.
device : torch.device
The device used to store the layer weights.
Either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.scale = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
self.bias = torch.nn.Parameter(
torch.zeros(num_channels, dtype=torch.float32, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the BiasNorm module.
Parameters
----------
x : torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len, num_channels). The module input.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len, num_channels).
A normalized output tensor of the same shape as input.
"""
return x * self.scale / torch.mean((x - self.bias)**2, dim=2, keepdim=True)**0.5
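

# A minimal usage sketch (illustration only, not part of the original module): BiasNorm
# keeps the input shape and normalizes each frame by the root-mean-square of (x - bias)
# over the channel dimension before applying the scalar output scale.
def _example_bias_norm() -> None:
    norm = BiasNorm(num_channels=8, device=torch.device("cpu"))
    x = torch.randn(1, 5, 8)
    y = norm(x)
    assert y.shape == x.shape

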
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
"""
    Behaves like a depthwise 1D convolution, except that it is causal in a chunkwise way, as if we
    had a block-triangular attention mask. The chunk size is provided at test time; it should be
    kept in sync with the attention mask.
    This has a little more than twice the parameters of a conventional depthwise conv1d module:
    we implement it by having one depthwise convolution, of half the width, that is causal (the
    input is left-padded with the cached left context), and one depthwise convolution that is
    applied only within chunks, which we multiply by a scaling factor that depends on the position
    within the chunk.
"""
def __init__(
self, num_channels: int, kernel_size: int, right_context: int, device: torch.device,
) -> None:
"""
ChunkCausalDepthwiseConv1d initialization.
Parameters
----------
num_channels : int
The number of input channels.
kernel_size : int
            The kernel size of the chunkwise convolution. The causal convolution kernel size is
            half of this value (rounded up). Should be an odd number.
right_context : int
            The lookahead (future) context of the module, used to correctly update the cached
            left context.
device : torch.device
The device used to store the layer weights.
Either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
if kernel_size % 2 == 0:
raise ValueError(
'Kernel size for ChunkCausalDepthwiseConv1d convolution '
f'module should be an odd number, but got {kernel_size}.',
)
self.kernel_size = kernel_size
self.right_context = right_context
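        # Two depthwise branches: a causal convolution of roughly half the kernel width
        # (no padding here; causality comes from concatenating the cached left context to
        # the input in forward()), and a chunkwise convolution with symmetric padding whose
        # output is rescaled near the chunk edges.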
self.causal_conv = torch.nn.Conv1d(
num_channels, num_channels, (kernel_size + 1) // 2, groups=num_channels, device=device,
)
self.chunkwise_conv = torch.nn.Conv1d(
num_channels,
num_channels,
kernel_size,
padding=kernel_size // 2,
groups=num_channels,
device=device,
)
        # The first row holds correction factors added to the scale near the left edge of the
        # chunk, and the second row holds correction factors added to the scale near the right
        # edge of the chunk; both are added to a default scale of 1.0.
self.chunkwise_conv_scale = torch.nn.Parameter(
torch.zeros(2, num_channels, kernel_size, dtype=torch.float32, device=device),
)
def forward(
self, x: torch.Tensor, left_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the ChunkCausalDepthwiseConv1d module. Returns processed tensor
of the same shape as input and updated cached convolution tensor of the left context.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, num_channels, seq_len). The module input.
left_cache : torch.Tensor[torch.float32]
A cached convolution tensor of the left context
of shape (1, num_channels, left_cache_len).
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- module output of shape (1, num_channels, seq_len).
A tensor with the same shape as input x.
- updated cached convolution tensor of the left context
of shape (1, num_channels, left_cache_len).
"""
seq_len = x.size(2)
x_chunk = self.chunkwise_conv(x) # does not change shape
x = torch.cat((left_cache, x), dim=2) # Pad with left cache
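        # The new left cache is the last left_cache_len frames that precede the lookahead
        # (right_context) frames at the end of the concatenated input.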
left_cache = x[
:,
:,
x.size(2) - self.right_context - left_cache.size(2):
x.size(2) - self.right_context,
] # Update cache
x_causal = self.causal_conv(x)
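        # Per-position scale for the chunkwise branch: a default of 1.0 plus learned
        # corrections near the left and right edges (the whole input is treated here as
        # a single chunk of length seq_len).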
if seq_len < self.kernel_size:
left_edge = self.chunkwise_conv_scale[:1, :, :seq_len]
right_edge = self.chunkwise_conv_scale[1:, :, self.kernel_size - seq_len:]
else:
pad = torch.zeros(
1, self.chunkwise_conv_scale.size(1), seq_len - self.kernel_size,
dtype=torch.float32,
device=self.chunkwise_conv_scale.device,
)
left_edge = torch.cat((self.chunkwise_conv_scale[:1], pad), dim=2)
right_edge = torch.cat((pad, self.chunkwise_conv_scale[1:]), dim=2)
chunk_scale = 1.0 + left_edge + right_edge
x = x_chunk * chunk_scale + x_causal
return x, left_cache
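

# A minimal streaming sketch (illustration only, not part of the original module): the
# input is processed chunk by chunk, and the left-context cache returned by one call is
# passed to the next call. The chunk length, kernel size and the cache length of
# kernel_size // 2 used here are assumptions chosen so that the shapes line up.
def _example_chunk_causal_depthwise_conv1d() -> None:
    num_channels, kernel_size, chunk_len = 4, 15, 16
    conv = ChunkCausalDepthwiseConv1d(
        num_channels, kernel_size, right_context=0, device=torch.device("cpu"),
    )
    # The causal branch consumes kernel_size // 2 frames of cached left context.
    left_cache = torch.zeros(1, num_channels, kernel_size // 2)
    for _ in range(3):  # three consecutive chunks
        chunk = torch.randn(1, num_channels, chunk_len)
        out, left_cache = conv(chunk, left_cache)
        assert out.shape == chunk.shape

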
class SwooshL(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass and returns Swoosh-L activation.
Parameters
----------
x : torch.Tensor[torch.float32]
A float tensor of an arbitrary shape (*). The module input.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of an arbitrary shape (*). A Swoosh-L activation output tensor
of the same shape as input x.
"""
logaddexp = torch.clamp(x - 4.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 4.0)))
return logaddexp - 0.08 * x - 0.035
class SwooshR(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass and returns Swoosh-R activation.
Parameters
----------
x : torch.Tensor[torch.float32]
A float tensor of an arbitrary shape (*). The module input.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of an arbitrary shape (*). A Swoosh-R activation output tensor
of the same shape as input x.
"""
logaddexp = torch.clamp(x - 1.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 1.0)))
return logaddexp - 0.08 * x - 0.313261687
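

# A small numeric sketch (illustration only, not part of the original module): the
# clamp/log1p expressions above are numerically stable evaluations of log(1 + exp(x - 4))
# and log(1 + exp(x - 1)), so for moderate inputs they match the naive formulas, and the
# offset of Swoosh-R makes it pass (approximately) through zero at x = 0.
def _example_swoosh_activations() -> None:
    x = torch.linspace(-5.0, 5.0, steps=11)
    naive_l = torch.log(1.0 + torch.exp(x - 4.0)) - 0.08 * x - 0.035
    naive_r = torch.log(1.0 + torch.exp(x - 1.0)) - 0.08 * x - 0.313261687
    assert torch.allclose(SwooshL()(x), naive_l, atol=1.0e-06)
    assert torch.allclose(SwooshR()(x), naive_r, atol=1.0e-06)
    assert abs(SwooshR()(torch.zeros(1)).item()) < 1.0e-06

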
def _test_whiten():
for proportion in [0.1, 0.5, 10.0]:
logging.info(f"_test_whiten(): proportion = {proportion}")
x = torch.randn(100, 128)
direction = torch.randn(128)
coeffs = torch.randn(100, 1)
x += proportion * direction * coeffs
x.requires_grad = True
        m = Whiten(
            1,  # num_groups
            5.0,  # whitening_limit
            prob=1.0,
            grad_scale=0.1,
        )
for _ in range(4):
y = m(x)
y_grad = torch.randn_like(x)
y.backward(gradient=y_grad)
if proportion < 0.2:
assert torch.allclose(x.grad, y_grad)
elif proportion > 1.0:
assert not torch.allclose(x.grad, y_grad)
def _test_balancer_sign():
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
x = x.detach()
x.requires_grad = True
m = Balancer(
probs.numel(),
channel_dim=0,
min_positive=0.05,
max_positive=0.95,
min_abs=0.0,
prob=1.0,
)
y_grad = torch.sign(torch.randn(probs.numel(), N))
y = m(x)
y.backward(gradient=y_grad)
print("_test_balancer_sign: x = ", x)
print("_test_balancer_sign: y grad = ", y_grad)
print("_test_balancer_sign: x grad = ", x.grad)
def _test_balancer_magnitude():
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = x.detach()
x.requires_grad = True
m = Balancer(
magnitudes.numel(),
channel_dim=0,
min_positive=0.0,
max_positive=1.0,
min_abs=0.2,
max_abs=0.7,
prob=1.0,
)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
y = m(x)
y.backward(gradient=y_grad)
print("_test_balancer_magnitude: x = ", x)
print("_test_balancer_magnitude: y grad = ", y_grad)
print("_test_balancer_magnitude: x grad = ", x.grad)
def _test_double_swish_deriv():
x = torch.randn(10, 12, dtype=torch.double) * 3.0
x.requires_grad = True
m = DoubleSwish()
tol = (1.2 - (-0.043637)) / 255.0
torch.autograd.gradcheck(m, x, atol=tol)
# for self-test.
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
x.requires_grad = True
y = m(x)
def _test_swooshl_deriv():
x = torch.randn(10, 12, dtype=torch.double) * 3.0
x.requires_grad = True
m = SwooshL()
tol = 1.0 / 255.0
torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
# for self-test.
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
x.requires_grad = True
y = m(x)
def _test_swooshr_deriv():
x = torch.randn(10, 12, dtype=torch.double) * 3.0
x.requires_grad = True
m = SwooshR()
tol = 1.0 / 255.0
torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
# for self-test.
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
x.requires_grad = True
y = m(x)
def _test_softmax():
a = torch.randn(2, 10, dtype=torch.float64)
b = a.clone()
a.requires_grad = True
b.requires_grad = True
a.softmax(dim=1)[:, 0].sum().backward()
print("a grad = ", a.grad)
softmax(b, dim=1)[:, 0].sum().backward()
print("b grad = ", b.grad)
assert torch.allclose(a.grad, b.grad)
def _test_piecewise_linear():
p = PiecewiseLinear((0, 10.0))
for x in [-100, 0, 100]:
assert p(x) == 10.0
p = PiecewiseLinear((0, 10.0), (1, 0.0))
for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
print("x, y = ", x, y)
assert p(x) == y, (x, p(x), y)
q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
pq = p.max(q)
for x in x_vals:
y1 = max(p(x), q(x))
y2 = pq(x)
assert abs(y1 - y2) < 0.001
pq = p.min(q)
for x in x_vals:
y1 = min(p(x), q(x))
y2 = pq(x)
assert abs(y1 - y2) < 0.001
pq = p + q
for x in x_vals:
y1 = p(x) + q(x)
y2 = pq(x)
assert abs(y1 - y2) < 0.001
def _test_activation_dropout_and_linear():
in_channels = 20
out_channels = 30
for bias in [True, False]:
# actually we don't test for dropout_p != 0.0 because forward functions will give
# different answers. This is because we are using the k2 implementation of
        # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
# internally, messing up the random state.
for dropout_p in [0.0]:
for activation in ["SwooshL", "SwooshR"]:
m1 = nn.Sequential(
SwooshL() if activation == "SwooshL" else SwooshR(),
Dropout3(p=dropout_p, shared_dim=-1),
ScaledLinear(
in_channels, out_channels, bias=bias, initial_scale=0.5
),
)
m2 = ActivationDropoutAndLinear(
in_channels,
out_channels,
bias=bias,
initial_scale=0.5,
activation=activation,
dropout_p=dropout_p,
)
with torch.no_grad():
m2.weight[:] = m1[2].weight
if bias:
m2.bias[:] = m1[2].bias
# make sure forward gives same result.
x1 = torch.randn(10, in_channels)
x1.requires_grad = True
# TEMP.
assert torch.allclose(
SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
)
x2 = x1.clone().detach()
x2.requires_grad = True
seed = 10
torch.manual_seed(seed)
y1 = m1(x1)
y_grad = torch.randn_like(y1)
y1.backward(gradient=y_grad)
torch.manual_seed(seed)
y2 = m2(x2)
y2.backward(gradient=y_grad)
print(
f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
)
print("y1 = ", y1)
print("y2 = ", y2)
assert torch.allclose(y1, y2, atol=0.02)
assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
if bias:
assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
print("x1.grad = ", x1.grad)
print("x2.grad = ", x2.grad)
def isclose(a, b):
# return true if cosine similarity is > 0.9.
return (a * b).sum() > 0.9 * (
(a**2).sum() * (b**2).sum()
).sqrt()
                # the SwooshL() implementation has a noisy gradient due to the 1-byte
                # storage it uses internally.
assert isclose(x1.grad, x2.grad)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_piecewise_linear()
_test_softmax()
_test_whiten()
_test_balancer_sign()
_test_balancer_magnitude()
_test_double_swish_deriv()
_test_swooshr_deriv()
_test_swooshl_deriv()
_test_activation_dropout_and_linear()