mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
416 lines
13 KiB
Python
416 lines
13 KiB
Python
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import collections
|
|
from itertools import repeat
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.nn import Embedding as ScaledEmbedding
|
|
|
|
|
|
def _ntuple(n):
|
|
def parse(x):
|
|
if isinstance(x, collections.Iterable):
|
|
return x
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
|
|
# Converters that turn a scalar into a 1-tuple / 2-tuple and pass iterables
# through unchanged.
_single, _pair = _ntuple(1), _ntuple(2)
|
|
|
|
|
|
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; adjusts the gradient in the backward pass.

    The forward pass returns ``x`` unchanged.  When ``x.requires_grad`` is
    true it also records per-channel statistics (proportion of positive
    entries and mean absolute value) which the backward pass uses to add a
    small correction, proportional to ``|x_grad|``, to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        channel_dim: int,
        min_positive: float,  # e.g. 0.05
        max_positive: float,  # e.g. 0.95
        max_factor: float,  # e.g. 0.01
        min_abs: float,  # e.g. 0.2
        max_abs: float,  # e.g. 100.0
    ) -> Tensor:
        """Return ``x`` and, if it requires grad, stash statistics on ``ctx``.

        Args:
          ctx: autograd context used to pass data to ``backward``.
          x: the input tensor.
          channel_dim: index of the channel axis; a negative value is
            offset by ``x.ndim``.
          min_positive: per-channel proportion of positive entries below
            which ``factor`` becomes non-zero.
          max_positive: per-channel proportion of positive entries above
            which ``factor`` becomes non-zero.
          max_factor: scale of the gradient correction.
          min_abs: channels whose mean absolute value is below this are
            flagged ``below_threshold``.
          max_abs: channels whose mean absolute value is above this are
            flagged ``above_threshold``.

        Returns:
          ``x`` itself.
        """
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            # Statistics are reduced over every axis except the channel axis.
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            xgt0 = x > 0
            proportion_positive = torch.mean(
                xgt0.to(x.dtype), dim=sum_dims, keepdim=True
            )
            # The conditional expression guards the division: when
            # min_positive == 0.0 the term is the float 0.0.
            factor1 = (
                (min_positive - proportion_positive).relu()
                * (max_factor / min_positive)
                if min_positive != 0.0
                else 0.0
            )
            # Likewise guarded for max_positive == 1.0.  For max_positive < 1
            # the coefficient (max_positive - 1.0) is negative, so factor2
            # has the opposite sign to factor1.
            factor2 = (
                (proportion_positive - max_positive).relu()
                * (max_factor / (max_positive - 1.0))
                if max_positive != 1.0
                else 0.0
            )
            factor = factor1 + factor2
            # Both terms were disabled: replace the float with a tensor so
            # that it can be passed to save_for_backward.
            if isinstance(factor, float):
                factor = torch.zeros_like(proportion_positive)

            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
            below_threshold = mean_abs < min_abs
            above_threshold = mean_abs > max_abs

            ctx.save_for_backward(
                factor, xgt0, below_threshold, above_threshold
            )
            ctx.max_factor = max_factor
            # NOTE: sum_dims is stored here but is not read by backward().
            ctx.sum_dims = sum_dims
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        """Return the adjusted gradient w.r.t. ``x``.

        The six trailing ``None`` values correspond to the non-tensor
        arguments of ``forward``.
        """
        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
        dtype = x_grad.dtype
        # (xgt0 - 0.5) * 2 * max_factor is +max_factor where x > 0 and
        # -max_factor elsewhere; it is kept with that sign for channels
        # flagged below_threshold, negated for channels flagged
        # above_threshold, and zero for all other channels.
        scale_factor = (
            (below_threshold.to(dtype) - above_threshold.to(dtype))
            * (xgt0.to(dtype) - 0.5)
            * (ctx.max_factor * 2.0)
        )

        # The correction is proportional to the magnitude of the gradient.
        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
        return x_grad - neg_delta_grad, None, None, None, None, None, None
|
|
|
|
|
|
class BasicNorm(torch.nn.Module):
    """
    A simpler, and hopefully cheaper, replacement for LayerNorm.

    The motivating observation is that Transformer-type networks, especially
    with pre-norm, sometimes seem to set one of the feature dimensions to a
    large constant value (e.g. 50), which "defeats" the LayerNorm because the
    output magnitude is then not strongly dependent on the other (useful)
    features.  Presumably the weight and bias of the LayerNorm are required
    to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.  We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
              scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # `self.eps` holds log(eps); forward() exponentiates it again.
        log_eps = torch.tensor(eps).log().detach()
        if learn_eps:
            self.eps = nn.Parameter(log_eps)
        else:
            self.register_buffer("eps", log_eps)

    def forward(self, x: Tensor) -> Tensor:
        """Scale `x` by (mean(x**2 over channel_dim) + exp(self.eps)) ** -0.5."""
        assert x.shape[self.channel_dim] == self.num_channels
        mean_square = torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
        scales = (mean_square + self.eps.exp()) ** -0.5
        return x * scales
|
|
|
|
|
|
class ScaledLinear(nn.Linear):
    """
    A modified version of nn.Linear that gives an easy way to set the
    default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: the weight produced by nn.Linear's own initialization
           is multiplied by this value, which increases or decreases the
           initial magnitude of the module's output.  The bias, if present,
           is re-initialized uniformly in [-0.1, 0.1] regardless of this
           value.  Another option, if you want to do something like this,
           is to re-initialize the parameters.
    """

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias, -0.1, 0.1)

    def get_weight(self):
        """Return the weight; not needed any more but kept for back compatibility."""
        return self.weight

    def get_bias(self):
        """Return the bias (None if the layer has no bias)."""
        return self.bias
|
|
|
|
|
|
|
|
class ScaledConv1d(nn.Conv1d):
    """nn.Conv1d whose initial weight is multiplied by `initial_scale`.

    Accepts the standard args and kwargs of nn.Conv1d.  After the parent
    class's initialization the weight is scaled by `initial_scale` and the
    bias, if present, is re-initialized uniformly in [-0.1, 0.1].
    See docs for ScaledLinear.
    """

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias, -0.1, 0.1)

    def get_weight(self):  # TODO: delete
        """Return the weight."""
        return self.weight

    def get_bias(self):  # TODO: delete
        """Return the bias (None if the layer has no bias)."""
        return self.bias
|
|
|
|
|
|
class ScaledConv2d(nn.Conv2d):
    """nn.Conv2d whose initial weight is multiplied by `initial_scale`.

    Accepts the standard args and kwargs of nn.Conv2d.  After the parent
    class's initialization the weight is scaled by `initial_scale` and the
    bias, if present, is re-initialized uniformly in [-0.1, 0.1].
    See docs for ScaledLinear.
    """

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias, -0.1, 0.1)

    def get_weight(self):
        """Return the weight."""
        return self.weight

    def get_bias(self):
        """Return the bias (None if the layer has no bias)."""
        return self.bias
|
|
|
|
|
|
|
|
class ActivationBalancer(torch.nn.Module):
    """
    Identity module that modifies the backpropped derivatives of its input,
    via ActivationBalancerFunction, to try to encourage each channel to be
    positive for a proportion of the time between `min_positive` and
    `max_positive`, and to have a mean absolute value between `min_abs` and
    `max_abs`.  When scripting with TorchScript, the input is returned as-is.

    Args:
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if
           negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
           either the sign constraint or the magnitude constraint;
           e.g. with max_factor=0.02, the derivatives would be multiplied by
           values in the range [0.98..1.02].
       min_abs: the minimum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
           this.
       max_abs: the maximum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
           this.
    """

    def __init__(
        self,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.01,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
    ):
        super().__init__()
        self.channel_dim = channel_dim
        # Limits on the per-channel proportion of positive values.
        self.min_positive, self.max_positive = min_positive, max_positive
        self.max_factor = max_factor
        # Limits on the per-channel mean absolute value.
        self.min_abs, self.max_abs = min_abs, max_abs

    def forward(self, x: Tensor) -> Tensor:
        """Return `x`; outside TorchScript, route it through the balancer function."""
        if torch.jit.is_scripting():
            return x

        balanced = ActivationBalancerFunction.apply(
            x,
            self.channel_dim,
            self.min_positive,
            self.max_positive,
            self.max_factor,
            self.min_abs,
            self.max_abs,
        )
        return balanced
|
|
|
|
|
|
class DoubleSwishFunction(torch.autograd.Function):
    """
      double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation, writing s(x) = torch.sigmoid(x-1):
       double_swish(x)  = x * s(x)
       double_swish'(x) = x * s'(x) + s(x)
                        = x * s(x) * (1-s(x)) + s(x)       [s' = s * (1-s)]
                        = double_swish(x) * (1-s(x)) + s(x)
    ... so we only need to remember s(x) and the output, not x itself.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        """Compute x * sigmoid(x - 1) and save what backward() needs."""
        inp = x.detach()
        gate = torch.sigmoid(inp - 1.0)
        out = inp * gate
        ctx.save_for_backward(gate, out)
        return out

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        """Multiply `y_grad` by the derivative rebuilt from the saved tensors."""
        gate, out = ctx.saved_tensors
        deriv = out * (1 - gate) + gate
        return deriv * y_grad
|
|
|
|
|
|
class DoubleSwish(torch.nn.Module):
    """Module wrapper for the double-swish activation x * sigmoid(x - 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        # TorchScript gets the plain expression; otherwise DoubleSwishFunction
        # is used.
        if torch.jit.is_scripting():
            gate = torch.sigmoid(x - 1.0)
            return x * gate
        return DoubleSwishFunction.apply(x)
|
|
|
|
|
|
|
|
|
|
def _test_activation_balancer_sign():
    """Print the gradients ActivationBalancer produces for 0/1 inputs.

    Row i of the input is 1 with probability i/100 and 0 otherwise; the
    magnitude constraint is switched off on the low side (min_abs=0.0).
    Nothing is asserted: the result is printed for inspection.
    """
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    num_rows = probs.numel()

    mask = torch.rand(num_rows, N) < probs.unsqueeze(-1)
    x = (1.0 * mask).detach()
    x.requires_grad = True

    balancer = ActivationBalancer(
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    # Random +/-1 output gradient.
    y_grad = torch.sign(torch.randn(num_rows, N))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)
|
|
|
|
|
|
def _test_activation_balancer_magnitude():
    """Print the gradients ActivationBalancer produces for varying magnitudes.

    Row i of the input is +/- i/100 with random signs; the sign constraint
    is switched off (min_positive=0.0, max_positive=1.0).
    Nothing is asserted: the result is printed for inspection.
    """
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    num_rows = magnitudes.numel()

    signs = torch.sign(torch.randn(num_rows, N))
    x = (signs * magnitudes.unsqueeze(-1)).detach()
    x.requires_grad = True

    balancer = ActivationBalancer(
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
    )

    # Random +/-1 output gradient.
    y_grad = torch.sign(torch.randn(num_rows, N))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
|
|
|
|
|
def _test_basic_norm():
    """Check that BasicNorm keeps the shape and shrinks the RMS by less than half."""
    num_channels = 128
    norm = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)
    y = norm(x)
    assert y.shape == x.shape

    def _rms(t):
        # Root-mean-square over all elements.
        return (t ** 2).mean().sqrt()

    x_rms = _rms(x)
    y_rms = _rms(y)
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
|
|
|
|
|
|
def _test_double_swish_deriv():
    """Run torch.autograd.gradcheck on DoubleSwish with double-precision input."""
    x = torch.randn(10, 12, dtype=torch.double) * 0.5
    x.requires_grad_(True)
    activation = DoubleSwish()
    torch.autograd.gradcheck(activation, x)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the module's self-tests in their original order.
    for _self_test in (
        _test_activation_balancer_sign,
        _test_activation_balancer_magnitude,
        _test_basic_norm,
        _test_double_swish_deriv,
    ):
        _self_test()
|