# Mirror of https://github.com/k2-fsa/icefall.git
# (synced 2025-08-09 01:52:41 +00:00; 680 lines, 25 KiB, Python)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from typing import Tuple
|
|
|
|
|
|
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Maps an input of shape (N, T, idim) to an output of shape (N, T', odim),
    where

        T' = ((T-1)//2 - 1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(self, idim: int, odim: int) -> None:
        """
        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >=7, idim >=7
          odim:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
        """
        assert idim >= 7
        super().__init__()
        # Two unpadded, kernel-3, stride-2 convolutions. Each one maps a
        # length L to (L - 1) // 2 along both the time and the feature axis.
        self.conv = nn.Sequential(
            ScaledConv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
            DerivBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
            DerivBalancer(channel_dim=1),
            DoubleSwish(),
        )
        # Feature-axis size left after the two convolutions.
        reduced_idim = ((idim - 1) // 2 - 1) // 2
        self.out = ScaledLinear(odim * reduced_idim, odim)
        self.out_norm = BasicNorm(odim)
        # Constrain the mean of the output to be close to zero.
        self.out_balancer = DerivBalancer(
            channel_dim=-1, min_positive=0.4, max_positive=0.6
        )
        self._reset_parameters()

    def _reset_parameters(self):
        # Initialize the first conv with a smaller-than-default variance;
        # otherwise it learns too slowly in relative terms (assuming training
        # with Adam).
        first_conv = self.conv[0]
        nn.init.normal_(first_conv.weight, std=0.05)
        nn.init.constant_(first_conv.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
        """
        # (N, T, idim) -> (N, 1, T, idim), i.e. (N, C, H, W)
        x = self.conv(x.unsqueeze(1))
        # x is now (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        batch, channels, frames, feats = x.size()
        # Move time next to batch, then flatten channels and features together.
        x = self.out(x.transpose(1, 2).reshape(batch, frames, channels * feats))
        # x is now (N, ((T-1)//2 - 1)//2, odim)
        return self.out_balancer(self.out_norm(x))
|
|
|
|
|
|
class VggSubsampling(nn.Module):
    """Trying to follow the setup described in the following paper:
    https://arxiv.org/pdf/1910.09799.pdf

    The paper is not 100% explicit, so parts of this are guesses, informed by
    other VGG implementations.

    Maps an input of shape (N, T, idim) to an output of shape (N, T', odim),
    where

        T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a VggSubsampling object.

        Uses 2 VGG blocks, each with 2 Conv2d layers and a max-pooling layer.
        Together they subsample the input by roughly a factor of 4 along both
        the time and the feature dimension.

        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >=7, idim >=7
          odim:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
        """
        super().__init__()

        block_dims = [32, 64]

        # Using padding=1 for the 1st convolution, padding=0 for the 2nd and
        # for the max-pooling, and ceil_mode=True, was a back-compatibility
        # decision. It makes the number of output frames equal to:
        #     (((T-1)//2)-1)//2.
        # Using padding=1 on the 2nd convolution instead would give T//4.
        layers = []
        in_channels = 1
        for out_channels in block_dims:
            layers.extend(
                [
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        padding=1,
                        stride=1,
                    ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        padding=0,
                        stride=1,
                    ),
                    torch.nn.MaxPool2d(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True
                    ),
                ]
            )
            in_channels = out_channels
        self.layers = nn.Sequential(*layers)

        # Feature-axis size left after the two VGG blocks.
        reduced_idim = ((idim - 1) // 2 - 1) // 2
        self.out = nn.Linear(block_dims[-1] * reduced_idim, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
        """
        features = self.layers(x.unsqueeze(1))
        batch, channels, frames, feats = features.size()
        flat = features.transpose(1, 2).reshape(batch, frames, channels * feats)
        return self.out(flat)
|
|
|
|
|
|
class PeLUFunction(torch.autograd.Function):
    """
    Computes the PeLU function; call it as PeLUFunction.apply(x, cutoff, alpha).
    The function is:

        x.relu() + alpha * (cutoff - x).relu()

    E.g. consider cutoff = -1, alpha = 0.01. Inputs at or below the cutoff
    still receive a (small, -alpha) gradient, which tends to prevent die-off
    of neurons.
    """

    @staticmethod
    def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor:
        shifted = cutoff - x
        # ">=" rather than ">" so that a derivative exists exactly at the kinks.
        ctx.save_for_backward(x >= 0, shifted >= 0)
        ctx.alpha = alpha
        return x.relu() + alpha * shifted.relu()

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]:
        pos_mask, below_cutoff_mask = ctx.saved_tensors
        # d/dx relu(x) = pos_mask; d/dx alpha*relu(cutoff - x) = -alpha * below_cutoff_mask
        x_grad = pos_mask * ans_grad - (ctx.alpha * below_cutoff_mask) * ans_grad
        return x_grad, None, None
|
|
|
|
|
|
|
|
class PeLU(torch.nn.Module):
    """Module wrapper around PeLUFunction:
    y = x.relu() + alpha * (cutoff - x).relu().

    Args:
      cutoff: inputs at or below this value get slope -alpha.
      alpha: magnitude of the slope below `cutoff`.
    """

    def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None:
        super().__init__()
        self.cutoff, self.alpha = cutoff, alpha

    def forward(self, x: Tensor) -> Tensor:
        return PeLUFunction.apply(x, self.cutoff, self.alpha)
|
|
|
|
class ExpScale(torch.nn.Module):
    """Multiplies its input by a learned scalar, stored in log space.

    The effective multiplier is ``exp(self.scale * speed)``. The stored
    parameter is initialized to ``log(initial_scale) / speed``, so the
    effective multiplier starts at `initial_scale`.

    Args:
      *shape: accepted for call compatibility but currently ignored; the
        learned scale is always a scalar (0-dim) parameter.
      speed: multiplier applied to the stored log-scale.
      initial_scale: initial value of the effective multiplier; should be
        positive, since its log is taken.
    """

    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
        super().__init__()
        log_scale = torch.tensor(initial_scale).log() / speed
        self.scale = nn.Parameter(log_scale.detach())
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        multiplier = (self.scale * self.speed).exp()
        return x * multiplier
|
|
|
|
|
|
|
|
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
|
# double-swish, implemented/approximated as offset-swish
|
|
x = (x * torch.sigmoid(x - 1.0))
|
|
x = x * (scale * speed).exp()
|
|
return x
|
|
|
|
class SwishExpScaleFunction(torch.autograd.Function):
    """Autograd function computing ``_exp_scale_swish(x, scale, speed)``.

    forward saves only detached aliases of `x` and `scale`, so the
    intermediate results of the forward computation are not kept. backward
    recomputes the forward under ``torch.enable_grad()`` and differentiates
    it.
    """

    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
        ctx.save_for_backward(x.detach(), scale.detach())
        ctx.speed = speed
        return _exp_scale_swish(x, scale, speed)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
        x, scale = ctx.saved_tensors
        # Differentiate fresh leaf aliases with torch.autograd.grad rather
        # than setting requires_grad on the saved tensors and calling
        # .backward(). That would mutate tensors held by ctx and accumulate
        # into their .grad, returning summed gradients if backward runs more
        # than once on a retained graph.
        x = x.detach().requires_grad_(True)
        scale = scale.detach().requires_grad_(True)
        with torch.enable_grad():
            y = _exp_scale_swish(x, scale, ctx.speed)
            x_grad, scale_grad = torch.autograd.grad(
                y, (x, scale), grad_outputs=y_grad
            )
        # `speed` is a plain float: no gradient.
        return x_grad, scale_grad, None
|
|
|
|
|
|
class SwishExpScale(torch.nn.Module):
    """Offset-swish activation followed by a learned scalar multiplier.

    Combines ExpScale and a Swish: the Swish is applied first, then the
    result is multiplied by ``exp(self.scale * speed)``.

    Caution: `speed` must be passed by name, e.g. SwishExpScale(50, speed=4.0).
    Positional arguments are collected into `*shape`, which is currently
    ignored. The learned log-scale is always a scalar, initialized to 0
    (i.e. an initial multiplier of 1).
    """

    def __init__(self, *shape, speed: float = 1.0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(()))
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
|
|
|
|
|
|
|
|
def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
|
return (x * (scale * speed).exp()).relu()
|
|
|
|
|
|
|
|
|
|
class ExpScaleReluFunction(torch.autograd.Function):
    """Autograd function computing ``_exp_scale_relu(x, scale, speed)``.

    forward saves only detached aliases of `x` and `scale`. backward
    recomputes the forward under ``torch.enable_grad()`` and differentiates
    it. `scale` may be broadcast against `x`; its gradient comes back in
    `scale`'s own shape.
    """

    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
        ctx.save_for_backward(x.detach(), scale.detach())
        ctx.speed = speed
        return _exp_scale_relu(x, scale, speed)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
        x, scale = ctx.saved_tensors
        # Differentiate fresh leaf aliases with torch.autograd.grad rather
        # than setting requires_grad on the saved tensors and calling
        # .backward(). That would mutate tensors held by ctx and accumulate
        # into their .grad, returning summed gradients if backward runs more
        # than once on a retained graph.
        x = x.detach().requires_grad_(True)
        scale = scale.detach().requires_grad_(True)
        with torch.enable_grad():
            y = _exp_scale_relu(x, scale, ctx.speed)
            x_grad, scale_grad = torch.autograd.grad(
                y, (x, scale), grad_outputs=y_grad
            )
        # `speed` is a plain float: no gradient.
        return x_grad, scale_grad, None
|
|
|
|
class ExpScaleRelu(torch.nn.Module):
    """ReLU combined with a learned multiplier ``exp(self.scale * speed)``.

    Caution: `speed` must be passed by name, e.g. ExpScaleRelu(50, speed=4.0).
    Positional arguments give the shape of the learned log-scale. It is
    initialized to zeros (a multiplier of 1) and has to broadcast against
    the input.
    """

    def __init__(self, *shape, speed: float = 1.0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(*shape))
        self.speed = speed

    def forward(self, x: Tensor) -> Tensor:
        return ExpScaleReluFunction.apply(x, self.scale, self.speed)
|
|
|
|
|
|
|
|
|
|
class DerivBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass. In the backward pass, adjusts the gradient
    per channel to push (a) the proportion of positive values into
    [min_positive, max_positive] and (b) the mean absolute value into
    [min_abs, max_abs].

    Called as DerivBalancerFunction.apply(x, channel_dim, min_positive,
    max_positive, max_factor, min_abs, max_abs).
    """

    @staticmethod
    def forward(ctx, x: Tensor,
                channel_dim: int,
                min_positive: float,  # e.g. 0.05
                max_positive: float,  # e.g. 0.95
                max_factor: float,  # e.g. 0.01
                min_abs: float,  # e.g. 0.2
                max_abs: float,  # e.g. 100.0
    ) -> Tensor:
        # Statistics are only gathered when a backward pass can happen.
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            # Statistics are per channel: reduce over every other axis.
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            xgt0 = x > 0
            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
            # factor1 >= 0. It is nonzero only when too few values are positive,
            # growing linearly from 0 (at min_positive) to max_factor (at 0).
            # min_positive == 0.0 disables this term.
            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
                       if min_positive != 0.0 else 0.0)
            # factor2 <= 0, since (max_positive - 1.0) < 0. It is nonzero only
            # when too many values are positive, reaching -max_factor when all
            # are. max_positive == 1.0 disables this term.
            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
                       if max_positive != 1.0 else 0.0)
            factor = factor1 + factor2
            # If both terms are disabled, factor is the Python float 0.0; turn it
            # into a tensor so it can go through save_for_backward.
            if isinstance(factor, float):
                factor = torch.zeros_like(proportion_positive)

            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
            below_threshold = (mean_abs < min_abs)
            above_threshold = (mean_abs > max_abs)

            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
            ctx.max_factor = max_factor
            ctx.sum_dims = sum_dims  # not read by backward below
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
        dtype = x_grad.dtype
        # Magnitude term: +max_factor where x > 0 and -max_factor elsewhere, in
        # channels whose mean |x| is below min_abs. The signs are flipped for
        # channels above max_abs, and the term is 0 for all other channels.
        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))

        # x_grad - |x_grad| * f multiplies positive gradients by (1 - f) and
        # negative ones by (1 + f). Assuming the optimizer steps against the
        # gradient, f > 0 biases x upward and f < 0 biases it downward.
        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
        # One gradient for x; the six non-tensor arguments get None.
        return x_grad - neg_delta_grad, None, None, None, None, None, None
|
|
|
|
|
|
class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm. The observation this is based on is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features. Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value explicitly, in the
    role of the "eps" in LayerNorm, so the network doesn't have to do this
    trick. We also remove the subtracting-the-mean aspect of LayerNorm (which
    anyway is not that useful unless the LayerNorm immediately follows a
    nonlinearity).

    In this implementation `eps` is a fixed constant, and there is no learned
    scale or bias. The output is

        y = x * (mean(x**2 over channel_dim) + eps) ** -0.5

    Args:
      num_channels: the size of the input along `channel_dim`. forward()
        asserts that the input matches it.
      channel_dim: the axis/dimension corresponding to the channel,
        interpreted as an offset from the input's ndim if negative.
        This is NOT the num_channels; it should typically be one of
        {-2, -1, 0, 1, 2, 3}.
      eps: the "epsilon" that we add as ballast in the formula above.
        Note: our epsilon is actually large, but we keep the name
        to indicate the connection with normal LayerNorm.
    """

    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 eps: float = 0.25):
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``(mean(x**2 over channel_dim) + eps) ** -0.5``.

        The output has the same shape as ``x``.
        """
        assert x.shape[self.channel_dim] == self.num_channels
        # keepdim=True so the per-position scale broadcasts back over channels.
        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
        return x * scales
|
|
|
|
|
|
class ScaledLinear(nn.Linear):
    """nn.Linear whose weight and bias are each multiplied by a learned scalar
    stored in log space.

    The effective weight is ``weight * exp(weight_scale * scale_speed)``, and
    likewise for the bias with `bias_scale`.

    Args:
      *args, **kwargs: forwarded to nn.Linear.
      scale_speed: multiplier applied to the stored log-scales. They are
        initialized to ``log(initial_scale) / scale_speed``.
      initial_scale: initial value of the effective multiplier for both
        weight and bias.
    """

    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        log_scale = torch.tensor(initial_scale).log() / scale_speed
        self.weight_scale = nn.Parameter(log_scale.clone().detach())
        self.scale_speed = scale_speed
        if self.bias is None:
            self.register_parameter('bias_scale', None)
        else:
            self.bias_scale = nn.Parameter(log_scale.clone().detach())

    def get_weight(self):
        """Return the effective (scaled) weight."""
        multiplier = (self.weight_scale * self.scale_speed).exp()
        return self.weight * multiplier

    def get_bias(self):
        """Return the effective (scaled) bias, or None if there is no bias."""
        if self.bias is None:
            return None
        multiplier = (self.bias_scale * self.scale_speed).exp()
        return self.bias * multiplier

    def forward(self, input: Tensor) -> Tensor:
        weight, bias = self.get_weight(), self.get_bias()
        return torch.nn.functional.linear(input, weight, bias)
|
|
|
|
|
|
class ScaledConv1d(nn.Conv1d):
    """nn.Conv1d whose weight and bias are each multiplied by a learned scalar
    stored in log space.

    The effective weight is ``weight * exp(weight_scale * scale_speed)``, and
    likewise for the bias with `bias_scale`.

    Args:
      *args, **kwargs: forwarded to nn.Conv1d.
      scale_speed: multiplier applied to the stored log-scales. They are
        initialized to ``log(initial_scale) / scale_speed``.
      initial_scale: initial value of the effective multiplier for both
        weight and bias.
    """

    def __init__(self, *args, scale_speed = 5.0,
                 initial_scale=1.0, **kwargs):
        super(ScaledConv1d, self).__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)

    def get_weight(self):
        """Return the effective (scaled) weight."""
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        """Return the effective (scaled) bias, or None if there is no bias."""
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, so the
            # convolution itself must use no padding. A plain int 0 is accepted
            # by F.conv1d, which avoids depending on torch's private `_single`
            # helper (never imported in this file).
            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            self.get_weight(), self.get_bias(), self.stride,
                            0, self.dilation, self.groups)
        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)
|
|
|
|
|
|
|
|
class ScaledConv2d(nn.Conv2d):
    """nn.Conv2d whose weight and bias are each multiplied by a learned scalar
    stored in log space.

    The effective weight is ``weight * exp(weight_scale * scale_speed)``, and
    likewise for the bias with `bias_scale`.

    Args:
      *args, **kwargs: forwarded to nn.Conv2d.
      scale_speed: multiplier applied to the stored log-scales. They are
        initialized to ``log(initial_scale) / scale_speed``.
      initial_scale: initial value of the effective multiplier for both
        weight and bias.
    """

    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
        super(ScaledConv2d, self).__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)

    def get_weight(self):
        """Return the effective (scaled) weight."""
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        """Return the effective (scaled) bias, or None if there is no bias."""
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def _conv_forward(self, input, weight):
        """Run the 2-D convolution with `weight` and the effective bias."""
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, so the
            # convolution itself must use no padding. A plain int 0 is accepted
            # by F.conv2d, which avoids depending on torch's private `_pair`
            # helper (never imported in this file).
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.get_bias(), self.stride,
                            0, self.dilation, self.groups)
        return F.conv2d(input, weight, self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.get_weight())
|
|
|
|
|
|
|
|
|
|
class DerivBalancer(torch.nn.Module):
    """
    Identity in the forward pass. Modifies the backpropped derivatives to try
    to encourage, for each channel:

      - that the proportion of values with (x > 0) lies within
        [min_positive, max_positive], and
      - that the mean absolute value lies within [min_abs, max_abs].

    For the sign constraint, when too few values are positive it multiplies
    negative derivative values by up to (1+max_factor), and positive
    derivative values by up to (1-max_factor). The amount is interpolated
    linearly from no change at min_positive to those extremal values when
    none of the inputs are positive. Symmetrically, when too many values are
    positive, the roles of positive and negative derivatives are swapped,
    interpolating from max_positive up to all inputs positive.

    The sign and magnitude adjustments are added together, so the combined
    modification can reach 2 * max_factor.

    Args:
      channel_dim: the dimension/axis corresponding to the channel, e.g.
        -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
      min_positive: the minimum, per channel, of the proportion of the time
        that (x > 0), below which we start to modify the derivatives.
        0.0 disables this check.
      max_positive: the maximum, per channel, of the proportion of the time
        that (x > 0), above which we start to modify the derivatives.
        1.0 disables this check.
      max_factor: the maximum factor by which each mechanism modifies the
        derivatives; e.g. with max_factor=0.02, the derivatives would be
        multiplied by values in the range [0.98..1.02].
      min_abs: the minimum average-absolute-value per channel which
        we allow before we start to modify the derivatives to prevent
        this. This is to prevent a failure mode where the activations
        become so small that the nonlinearity effectively becomes linear,
        which makes the module useless and it gets even smaller
        to try to "turn it off" completely.
      max_abs: the maximum average-absolute-value per channel which
        we allow before we start to modify the derivatives to prevent
        this. This is to prevent the possibility of activations getting
        out of floating point numerical range (especially in half precision).
    """

    def __init__(self, channel_dim: int,
                 min_positive: float = 0.05,
                 max_positive: float = 1.0,
                 max_factor: float = 0.01,
                 min_abs: float = 0.2,
                 max_abs: float = 100.0):
        super(DerivBalancer, self).__init__()
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs

    def forward(self, x: Tensor) -> Tensor:
        return DerivBalancerFunction.apply(x, self.channel_dim,
                                           self.min_positive, self.max_positive,
                                           self.max_factor, self.min_abs,
                                           self.max_abs)
|
|
|
|
|
|
def _double_swish(x: Tensor) -> Tensor:
|
|
# double-swish, implemented/approximated as offset-swish
|
|
return x * torch.sigmoid(x - 1.0)
|
|
|
|
class DoubleSwishFunction(torch.autograd.Function):
    """Autograd function computing ``_double_swish(x)``.

    forward saves only a detached alias of `x`. backward recomputes the
    forward under ``torch.enable_grad()`` and differentiates it.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x.detach())
        return _double_swish(x)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        # TODO: can make this more efficient (e.g. closed-form derivative).
        x, = ctx.saved_tensors
        # Differentiate a fresh leaf alias with torch.autograd.grad rather
        # than setting requires_grad on the saved tensor and calling
        # .backward(). That would mutate the tensor held by ctx and accumulate
        # into its .grad, returning summed gradients if backward runs more
        # than once on a retained graph.
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = _double_swish(x)
            x_grad, = torch.autograd.grad(y, x, grad_outputs=y_grad)
        return x_grad
|
|
|
|
class DoubleSwish(torch.nn.Module):
    """Double-swish activation module, backed by DoubleSwishFunction."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply double-swish, an approximation to Swish(Swish(x)).

        It is computed (closely approximated) as x * sigmoid(x - 1).
        """
        activation = DoubleSwishFunction.apply
        return activation(x)
|
|
|
|
def _test_exp_scale_swish():
    """Check that SwishExpScale matches DoubleSwish followed by ExpScale, both
    in output and in input gradient."""
    inp_fused = torch.randn(50, 60).detach()
    inp_ref = inp_fused.detach()

    fused = SwishExpScale(50, 1, speed=4.0)
    reference = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))

    for leaf in (inp_fused, inp_ref):
        leaf.requires_grad = True

    out_fused = fused(inp_fused)
    out_ref = reference(inp_ref)
    assert torch.allclose(out_fused, out_ref, atol=1e-05)

    out_fused.sum().backward()
    out_ref.sum().backward()
    assert torch.allclose(inp_fused.grad, inp_ref.grad, atol=1e-05)
|
|
|
|
def _test_exp_scale_relu():
    """Check that ExpScaleRelu matches ReLU followed by ExpScale, both in
    output and in input gradient."""
    inp_fused = torch.randn(50, 60).detach()
    inp_ref = inp_fused.detach()

    fused = ExpScaleRelu(50, 1, speed=4.0)
    reference = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0))

    for leaf in (inp_fused, inp_ref):
        leaf.requires_grad = True

    out_fused = fused(inp_fused)
    out_ref = reference(inp_ref)
    assert torch.allclose(out_fused, out_ref)

    out_fused.sum().backward()
    out_ref.sum().backward()
    assert torch.allclose(inp_fused.grad, inp_ref.grad)
|
|
|
|
|
|
|
|
def _test_deriv_balancer_sign():
    """Print the gradients DerivBalancer produces for its sign constraint.

    Row i of x is 1 with probability probs[i] and 0 otherwise, so the rows
    cover proportions of positive values from 0 to ~1. Rows outside
    [min_positive, max_positive] should have their gradients modified.
    """
    channel_dim = 0
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
    x = x.detach()
    x.requires_grad = True
    # Use the local `channel_dim` (previously defined but unused).
    m = DerivBalancer(channel_dim=channel_dim, min_positive=0.05, max_positive=0.95,
                      max_factor=0.2, min_abs=0.0)

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_deriv_balancer_sign: x = ", x)
    print("_test_deriv_balancer_sign: y grad = ", y_grad)
    print("_test_deriv_balancer_sign: x grad = ", x.grad)
|
|
|
|
def _test_deriv_balancer_magnitude():
    """Print the gradients DerivBalancer produces for its magnitude constraint.

    Row i of x has constant magnitude magnitudes[i] with random signs, and
    the sign constraint is disabled (min_positive=0, max_positive=1). Rows
    whose mean |x| is outside [min_abs, max_abs] should have their gradients
    modified.
    """
    channel_dim = 0
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    # Use the local `channel_dim` (previously defined but unused).
    m = DerivBalancer(channel_dim=channel_dim,
                      min_positive=0.0, max_positive=1.0,
                      max_factor=0.2,
                      min_abs=0.2, max_abs=0.8)

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_deriv_balancer_magnitude: x = ", x)
    print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
    print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
|
|
|
|
|
|
def _test_basic_norm():
    """Check that BasicNorm keeps the input shape and that, for standard-normal
    input, the output RMS is smaller than, but more than half of, the input
    RMS."""
    num_channels = 128
    norm = BasicNorm(num_channels=num_channels, channel_dim=1)

    inp = torch.randn(500, num_channels)
    out = norm(inp)

    assert out.shape == inp.shape
    x_rms = inp.pow(2).mean().sqrt()
    y_rms = out.pow(2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the module's self-tests in order.
    for _self_test in (
        _test_deriv_balancer_sign,
        _test_deriv_balancer_magnitude,
        _test_exp_scale_swish,
        _test_exp_scale_relu,
        _test_basic_norm,
    ):
        _self_test()
|