# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import Optional, Tuple, Union

import k2
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd


class BiasNorm(torch.nn.Module):
    """
    This is intended to be a simpler replacement for LayerNorm. The observation this is based on
    is that Transformer-type networks, especially with pre-normalization, sometimes seem to set
    one of the channel dimensions to a large constant value (e.g. 50), which "defeats" the
    LayerNorm because the output magnitude is then not strongly dependent on the other (useful)
    channels. Presumably the weight and bias of the LayerNorm are required to allow it to do this.
    Instead, we give the BiasNorm a trainable bias that it can use when computing the scale for
    normalization. We also give it a scalar trainable scale on the output.
    """

    def __init__(self, num_channels: int, device: torch.device) -> None:
        """
        BiasNorm initialization.

        Parameters
        ----------
        num_channels : int
            The number of input channels.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.scale = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
        self.bias = torch.nn.Parameter(
            torch.zeros(num_channels, dtype=torch.float32, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the BiasNorm module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len, num_channels). The module input.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len, num_channels).
            A normalized output tensor of the same shape as the input.
        """

        return x * self.scale / torch.mean((x - self.bias)**2, dim=2, keepdim=True)**0.5


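# --- Illustrative usage sketch (not part of the original icefall module) ---
# A minimal example of how the BiasNorm above might be exercised, assuming a batch size of 1
# as documented in its forward(). The function name, channel count, and sequence length are
# arbitrary choices made for illustration; it mirrors the style of the _test_*() helpers below.
def _example_bias_norm() -> None:
    norm = BiasNorm(num_channels=8, device=torch.device("cpu"))
    x = torch.randn(1, 5, 8)
    y = norm(x)
    # The output keeps the input shape; only the per-frame scale changes.
    assert y.shape == x.shape

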
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
    """
    Behaves like a depthwise 1D convolution, except that it is causal in a chunkwise way, as if we
    had a block-triangular attention mask. The chunk size is provided at test time; it should be
    kept in sync with the attention mask.

    This has a little more than twice the parameters of a conventional depthwise conv1d module:
    we implement it by having one depthwise convolution, of half the width, that is causal (via
    right padding), and one depthwise convolution that is applied only within chunks, which we
    multiply by a scaling factor that depends on the position within the chunk.
    """

    def __init__(
        self, num_channels: int, kernel_size: int, right_context: int, device: torch.device,
    ) -> None:
        """
        ChunkCausalDepthwiseConv1d initialization.

        Parameters
        ----------
        num_channels : int
            The number of input channels.
        kernel_size : int
            The kernel size for the chunkwise convolution. The causal convolution kernel size is
            half of this value. Should be an odd number.
        right_context : int
            The module's lookahead (future context), used to update the left cache correctly.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        if kernel_size % 2 == 0:
            raise ValueError(
                'Kernel size for ChunkCausalDepthwiseConv1d convolution '
                f'module should be an odd number, but got {kernel_size}.',
            )

        self.kernel_size = kernel_size
        self.right_context = right_context

        self.causal_conv = torch.nn.Conv1d(
            num_channels, num_channels, (kernel_size + 1) // 2, groups=num_channels, device=device,
        )

        self.chunkwise_conv = torch.nn.Conv1d(
            num_channels,
            num_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=num_channels,
            device=device,
        )

        # First row is correction factors added to the scale near the left edge of the chunk,
        # second row is correction factors added to the scale near the right edge of the chunk,
        # both of these are added to a default scale of 1.0.
        self.chunkwise_conv_scale = torch.nn.Parameter(
            torch.zeros(2, num_channels, kernel_size, dtype=torch.float32, device=device),
        )

    def forward(
        self, x: torch.Tensor, left_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the ChunkCausalDepthwiseConv1d module. Returns a processed tensor
        of the same shape as the input and an updated cached convolution tensor of the left
        context.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, num_channels, seq_len). The module input.
        left_cache : torch.Tensor[torch.float32]
            A cached convolution tensor of the left context
            of shape (1, num_channels, left_cache_len).

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - module output of shape (1, num_channels, seq_len).
              A tensor with the same shape as the input x.
            - updated cached convolution tensor of the left context
              of shape (1, num_channels, left_cache_len).
        """

        seq_len = x.size(2)

        x_chunk = self.chunkwise_conv(x)  # does not change shape

        x = torch.cat((left_cache, x), dim=2)  # Pad with left cache
        left_cache = x[
            :,
            :,
            x.size(2) - self.right_context - left_cache.size(2):
            x.size(2) - self.right_context,
        ]  # Update cache

        x_causal = self.causal_conv(x)

        if seq_len < self.kernel_size:
            left_edge = self.chunkwise_conv_scale[:1, :, :seq_len]
            right_edge = self.chunkwise_conv_scale[1:, :, self.kernel_size - seq_len:]
        else:
            pad = torch.zeros(
                1, self.chunkwise_conv_scale.size(1), seq_len - self.kernel_size,
                dtype=torch.float32,
                device=self.chunkwise_conv_scale.device,
            )
            left_edge = torch.cat((self.chunkwise_conv_scale[:1], pad), dim=2)
            right_edge = torch.cat((pad, self.chunkwise_conv_scale[1:]), dim=2)

        chunk_scale = 1.0 + left_edge + right_edge

        x = x_chunk * chunk_scale + x_causal

        return x, left_cache


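# --- Illustrative streaming usage sketch (not part of the original icefall module) ---
# Shows how the left cache returned by ChunkCausalDepthwiseConv1d.forward() above can be threaded
# from one chunk to the next. The initial cache length used here, (kernel_size + 1) // 2 - 1, is
# an assumption derived from the causal convolution having kernel (kernel_size + 1) // 2 and no
# padding, so that the causal branch's output length matches seq_len. The hyper-parameters
# (channels, kernel size, chunk length) and the function name are arbitrary.
def _example_chunk_causal_conv_streaming() -> None:
    num_channels, kernel_size, right_context = 4, 5, 0
    conv = ChunkCausalDepthwiseConv1d(
        num_channels, kernel_size, right_context, device=torch.device("cpu"),
    )
    left_cache = torch.zeros(1, num_channels, (kernel_size + 1) // 2 - 1)
    for _ in range(3):  # process three consecutive chunks
        chunk = torch.randn(1, num_channels, 16)
        y, left_cache = conv(chunk, left_cache)
        # Output matches the chunk shape; the cache keeps its length between calls.
        assert y.shape == chunk.shape

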
class SwooshL(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass and returns the Swoosh-L activation.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            A float tensor of an arbitrary shape (*). The module input.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of an arbitrary shape (*). A Swoosh-L activation output tensor
            of the same shape as the input x.
        """

        # Numerically stable evaluation of log(1 + exp(x - 4.0)), i.e. logaddexp(0, x - 4.0).
        logaddexp = torch.clamp(x - 4.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 4.0)))

        return logaddexp - 0.08 * x - 0.035


class SwooshR(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass and returns the Swoosh-R activation.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            A float tensor of an arbitrary shape (*). The module input.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of an arbitrary shape (*). A Swoosh-R activation output tensor
            of the same shape as the input x.
        """

        # Numerically stable evaluation of log(1 + exp(x - 1.0)), i.e. logaddexp(0, x - 1.0).
        logaddexp = torch.clamp(x - 1.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 1.0)))

        return logaddexp - 0.08 * x - 0.313261687


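# --- Illustrative check sketch (not part of the original icefall module) ---
# The two activations above have simple closed forms:
#   SwooshL(x) = log(1 + exp(x - 4)) - 0.08 * x - 0.035
#   SwooshR(x) = log(1 + exp(x - 1)) - 0.08 * x - 0.313261687
# The function below compares the hand-rolled stable logaddexp used above against
# torch.logaddexp; the tensor shape and tolerance are arbitrary choices, and the
# function name is hypothetical.
def _example_swoosh_closed_form() -> None:
    x = torch.randn(3, 7)
    l_ref = torch.logaddexp(torch.zeros_like(x), x - 4.0) - 0.08 * x - 0.035
    r_ref = torch.logaddexp(torch.zeros_like(x), x - 1.0) - 0.08 * x - 0.313261687
    assert torch.allclose(SwooshL()(x), l_ref, atol=1.0e-06)
    assert torch.allclose(SwooshR()(x), r_ref, atol=1.0e-06)

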
def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x += proportion * direction * coeffs

        x.requires_grad = True

        m = Whiten(
            1,  # num_groups
            5.0,  # whitening_limit
            prob=1.0,
            grad_scale=0.1,
        )

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)


def _test_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        min_abs=0.0,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_sign: x = ", x)
    print("_test_balancer_sign: y grad = ", y_grad)
    print("_test_balancer_sign: x grad = ", x.grad)


def _test_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.2,
        max_abs=0.7,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_magnitude: x = ", x)
    print("_test_balancer_magnitude: y grad = ", y_grad)
    print("_test_balancer_magnitude: x grad = ", x.grad)


def _test_double_swish_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = DoubleSwish()

    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)


def _test_swooshl_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = SwooshL()

    tol = 1.0 / 255.0
    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)


def _test_swooshr_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = SwooshR()

    tol = 1.0 / 255.0
    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)


def _test_softmax():
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)


def _test_piecewise_linear():
    p = PiecewiseLinear((0, 10.0))
    for x in [-100, 0, 100]:
        assert p(x) == 10.0
    p = PiecewiseLinear((0, 10.0), (1, 0.0))
    for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
        print("x, y = ", x, y)
        assert p(x) == y, (x, p(x), y)

    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
    x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
    pq = p.max(q)
    for x in x_vals:
        y1 = max(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p.min(q)
    for x in x_vals:
        y1 = min(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p + q
    for x in x_vals:
        y1 = p(x) + q(x)
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001


def _test_activation_dropout_and_linear():
    in_channels = 20
    out_channels = 30

    for bias in [True, False]:
        # Actually we don't test with dropout_p != 0.0, because the forward functions would give
        # different answers. This is because we are using the k2 implementation of
        # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
        # internally, messing up the random state.
        for dropout_p in [0.0]:
            for activation in ["SwooshL", "SwooshR"]:
                m1 = nn.Sequential(
                    SwooshL() if activation == "SwooshL" else SwooshR(),
                    Dropout3(p=dropout_p, shared_dim=-1),
                    ScaledLinear(
                        in_channels, out_channels, bias=bias, initial_scale=0.5
                    ),
                )
                m2 = ActivationDropoutAndLinear(
                    in_channels,
                    out_channels,
                    bias=bias,
                    initial_scale=0.5,
                    activation=activation,
                    dropout_p=dropout_p,
                )
                with torch.no_grad():
                    m2.weight[:] = m1[2].weight
                    if bias:
                        m2.bias[:] = m1[2].bias
                # make sure forward gives the same result.
                x1 = torch.randn(10, in_channels)
                x1.requires_grad = True

                # TEMP.
                assert torch.allclose(
                    SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
                )

                x2 = x1.clone().detach()
                x2.requires_grad = True
                seed = 10
                torch.manual_seed(seed)
                y1 = m1(x1)
                y_grad = torch.randn_like(y1)
                y1.backward(gradient=y_grad)
                torch.manual_seed(seed)
                y2 = m2(x2)
                y2.backward(gradient=y_grad)

                print(
                    f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
                )
                print("y1 = ", y1)
                print("y2 = ", y2)
                assert torch.allclose(y1, y2, atol=0.02)
                assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
                if bias:
                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
                print("x1.grad = ", x1.grad)
                print("x2.grad = ", x2.grad)

                def isclose(a, b):
                    # return true if cosine similarity is > 0.9.
                    return (a * b).sum() > 0.9 * (
                        (a**2).sum() * (b**2).sum()
                    ).sqrt()

                # the SwooshL() implementation has a noisy gradient due to 1-byte
                # storage of it.
                assert isclose(x1.grad, x2.grad)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_piecewise_linear()
    _test_softmax()
    _test_whiten()
    _test_balancer_sign()
    _test_balancer_magnitude()
    _test_double_swish_deriv()
    _test_swooshr_deriv()
    _test_swooshl_deriv()
    _test_activation_dropout_and_linear()