Replace ActivationBalancer with Balancer

Daniel Povey 2022-12-29 20:34:46 +08:00
parent c6bad1ee4f
commit 59be36181c
2 changed files with 246 additions and 16 deletions

scaling.py

@@ -23,6 +23,7 @@ import logging
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
+import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@@ -30,7 +31,6 @@ from torch.nn import Embedding as ScaledEmbedding
class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -975,6 +975,179 @@ class ActivationBalancer(torch.nn.Module):
            return _no_op(x)
class BalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min_mean: float,
        max_mean: float,
        min_rms: float,
        max_rms: float,
        grad_scale: float,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        ctx.save_for_backward(x)
        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        x, = ctx.saved_tensors
        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x = x.to(torch.float32)
                x = x.detach()
                x.requires_grad = True
                mean_dims = [i for i in range(x.ndim) if i != channel_dim]
                uncentered_var = (x ** 2).mean(dim=mean_dims)
                mean = x.mean(dim=mean_dims)
                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
                m = mean / stddev
                # part of loss that relates to mean / stddev
                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
                r_loss = (rms_clamped / rms).log().abs()
                loss = (m_loss + r_loss).sum()
                loss.backward()
                loss_grad = x.grad
                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
                x_grad_float = x_grad.to(torch.float32)
                # scale each element of loss_grad by the absolute value of the corresponding
                # element of x_grad, which we view as a noisy estimate of its magnitude for that
                # (frame and dimension).  later we can consider factored versions.
                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                x_grad = x_grad_mod.to(x_grad.dtype)
        return x_grad, None, None, None, None, None, None
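To see what the backward pass above does in isolation, here is a minimal standalone sketch (toy shapes and limits, not taken from the commit): it builds the same penalty on per-channel (mean / stddev) and RMS statistics, normalizes the penalty gradient to RMS grad_scale per channel, and adds it to an incoming gradient scaled elementwise by that gradient's absolute value.

import torch

# toy input: (frames, channels), channel_dim == 1; channels 0-2 each violate
# one of the limits below, channel 3 is within all of them.
x = (torch.randn(50, 4) * torch.tensor([0.05, 1.0, 3.0, 1.0])
     + torch.tensor([0.0, 2.0, 0.0, 0.0]))
x.requires_grad_(True)
x_grad = torch.randn_like(x)           # stand-in for the incoming gradient
min_mean, max_mean = -0.5, 0.5         # limits on (mean / stddev) per channel
min_rms, max_rms = 0.25, 2.0           # limits on RMS value per channel
grad_scale = 0.02

mean_dims = [0]                        # every dim except the channel dim
uncentered_var = (x ** 2).mean(dim=mean_dims)
mean = x.mean(dim=mean_dims)
stddev = (uncentered_var - mean * mean).clamp(min=1.0e-20).sqrt()
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
m = mean / stddev
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()          # zero while inside the limits
r_loss = (rms.clamp(min=min_rms, max=max_rms) / rms).log().abs()  # zero while inside the limits
(loss_grad,) = torch.autograd.grad((m_loss + r_loss).sum(), x)

# normalize the penalty gradient to RMS == grad_scale per channel ...
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
# ... then add it, modulated by the magnitude of the existing gradient.
x_grad_mod = x_grad + x_grad.abs() * loss_grad
print((x_grad_mod - x_grad).abs().mean(dim=0))   # only the violating channels get nudged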
class Balancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that the proportion of positive values lies between min_positive
    and max_positive, and that the average magnitude (expressed internally as a
    constraint on the RMS value) lies between min_abs and max_abs.  It does this
    by adding to the existing gradient a small correction term, derived from a
    penalty on the per-channel (mean / stddev) and RMS statistics and scaled
    elementwise by the absolute value of the existing gradient.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       min_abs: the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       grad_scale: determines the scale of the correction term added to the
           gradient, relative to the absolute value of the existing gradient
           for each element.
       prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward().  This is done randomly to prevent all layers
           from doing it at the same time.
    """
    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: FloatLike = 0.05,
        max_positive: FloatLike = 0.95,
        min_abs: FloatLike = 0.2,
        max_abs: FloatLike = 100.0,
        grad_scale: FloatLike = 0.02,
        prob: Optional[FloatLike] = None,
    ):
        super().__init__()

        if prob is None:
            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
        # 5% of the time we will return and do nothing because memory usage is
        # too high.
        self.mem_cutoff = CutoffEstimator(0.05)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        if (torch.jit.is_scripting() or not x.requires_grad or
            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
            return _no_op(x)

        prob = float(self.prob)
        if random.random() < prob:
            # The following inner-functions convert from the way we historically specified
            # these limitations, as limits on the absolute value and the proportion of positive
            # values, to limits on the RMS value and the (mean / stddev).
            def _abs_to_rms(x):
                # for normally distributed data, if the expected absolute value is x, the
                # expected rms value will be sqrt(pi/2) * x.
                return 1.25331413732 * x

            def _proportion_positive_to_mean(x):
                def _atanh(x):
                    eps = 1.0e-10
                    # eps is to prevent crashes if x is exactly 0 or 1.
                    # we'll just end up returning a fairly large value.
                    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.

                def _approx_inverse_erf(x):
                    # 1 / (sqrt(pi) * ln(2)),
                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
                    # this approximation is extremely crude and gets progressively worse for
                    # x very close to -1 or +1, but we mostly care about the "middle" region
                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
                    # which is pretty close to 0.05.
                    return 0.8139535143 * _atanh(x)

                # first convert x from the range 0..1 to the range -1..1 which the error
                # function returns
                x = -1 + (2 * x)
                return _approx_inverse_erf(x)

            min_mean = _proportion_positive_to_mean(float(self.min_positive))
            max_mean = _proportion_positive_to_mean(float(self.max_positive))
            min_rms = _abs_to_rms(float(self.min_abs))
            max_rms = _abs_to_rms(float(self.max_abs))
            grad_scale = float(self.grad_scale)

            assert x.shape[self.channel_dim] == self.num_channels

            return BalancerFunction.apply(
                x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
            )
        else:
            return _no_op(x)
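The two inner conversion helpers can be sanity-checked in isolation. The short standalone sketch below mirrors them under the same Gaussian assumption: it confirms that 1.25331413732 is sqrt(pi/2), shows how rough the inverse-erf approximation is in the middle of the range, and prints the limits that the default Balancer settings translate into.

import math

def abs_to_rms(a):
    # for Gaussian data, E[|x|] = sqrt(2/pi) * rms, so rms = sqrt(pi/2) * E[|x|]
    return 1.25331413732 * a

def proportion_positive_to_mean(p):
    def atanh(y):
        eps = 1.0e-10
        return (math.log(1 + y + eps) - math.log(1 - y + eps)) / 2.0
    def approx_inverse_erf(y):
        return 0.8139535143 * atanh(y)   # 0.8139535143 == 1 / (sqrt(pi) * ln(2))
    return approx_inverse_erf(2 * p - 1)

print(math.sqrt(math.pi / 2))      # 1.2533141373155003, matching the constant above
for y in (0.05, 0.5, 0.9):
    # erf(approx_inverse_erf(y)) should roughly recover y in the middle of the range
    print(y, math.erf(proportion_positive_to_mean((1 + y) / 2)))

# limits implied by the default Balancer settings
print(proportion_positive_to_mean(0.05), proportion_positive_to_mean(0.95))  # min_mean, max_mean
print(abs_to_rms(0.2), abs_to_rms(100.0))                                    # min_rms, max_rms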
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
                           name: str = None) -> Tensor:
    """
@@ -1731,6 +1904,7 @@ def _test_activation_balancer_sign():
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
+        prob=1.0,
    )
    y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -1742,6 +1916,31 @@ def _test_activation_balancer_sign():
    print("_test_activation_balancer_sign: x grad = ", x.grad)
def _test_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        min_abs=0.0,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_sign: x = ", x)
    print("_test_balancer_sign: y grad = ", y_grad)
    print("_test_balancer_sign: x grad = ", x.grad)
def _test_activation_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
@@ -1770,6 +1969,34 @@ def _test_activation_balancer_magnitude():
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
def _test_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
        -1
    )
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.2,
        max_abs=0.7,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_magnitude: x = ", x)
    print("_test_balancer_magnitude: y grad = ", y_grad)
    print("_test_balancer_magnitude: x grad = ", x.grad)
def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)
@@ -1862,7 +2089,9 @@ if __name__ == "__main__":
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
+    _test_balancer_sign()
    _test_activation_balancer_magnitude()
+    _test_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
    _test_tan_swish_deriv()
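Typical usage mirrors the new tests above: Balancer is the identity in the forward pass and only adds a correction on the backward pass. A minimal sketch, assuming scaling.py and its dependencies are importable from the current directory:

import torch
from scaling import Balancer   # assumes scaling.py is on the Python path

balancer = Balancer(256, channel_dim=-1,
                    min_positive=0.05, max_positive=0.95,
                    min_abs=0.2, max_abs=4.0,
                    prob=1.0)              # prob=1.0 so the constraint is always applied

x = torch.randn(10, 50, 256, requires_grad=True)
y = balancer(x)                            # forward pass leaves the values unchanged
assert torch.equal(y, x)
y.sum().backward()                         # gradients of out-of-range channels get a correction
print(x.grad.shape)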

zipformer.py

@@ -26,6 +26,7 @@ import random
from encoder_interface import EncoderInterface
from scaling import (
    ActivationBalancer,
+    Balancer,
    BasicNorm,
    ConvNorm1d,
    ConvNorm2d,
@@ -456,7 +457,7 @@ class ZipformerEncoderLayer(nn.Module):
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(
+        self.balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.45, max_positive=0.55,
            min_abs=1.0, max_abs=6.0,
@@ -1302,7 +1303,7 @@ class AttentionSqueeze(nn.Module):
        # instances of this module, the mean absolute values per channel are in
        # the range 0.1 to 0.4.  We apply the upper limit of 0.4 at the
        # beginning, and make it looser over time.
-        self.bottleneck_balancer = ActivationBalancer(
+        self.bottleneck_balancer = Balancer(
            bottleneck_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.05,
@@ -1316,13 +1317,13 @@
        # many degrees of freedom for the scales of the various activations.
        # Make them run with very low probability, since only a small
        # application of these balancers should be enough to stop such "drift".
-        self.scale_balancer = ActivationBalancer(
+        self.scale_balancer = Balancer(
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
            prob=_balancer_schedule(0.05),
        )
-        self.activation_balancer = ActivationBalancer(
+        self.activation_balancer = Balancer(
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
@@ -1339,7 +1340,7 @@
        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                     bias=False, initial_scale=0.05)
-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1396,7 +1397,7 @@ class FeedforwardModule(nn.Module):
        super(FeedforwardModule, self).__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
-        self.hidden_balancer = ActivationBalancer(feedforward_dim,
+        self.hidden_balancer = Balancer(feedforward_dim,
                                        channel_dim=-1,
                                        min_positive=0.3,
                                        max_positive=1.0,
@@ -1447,7 +1448,7 @@ class NonlinAttentionModule(nn.Module):
        # because we noticed that well-trained instances of this module have abs-value before the sigmoid
        # starting from about 3, and poorly-trained instances of the module have smaller abs values
        # before the sigmoid.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
            hidden_channels, channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
@@ -1473,7 +1474,7 @@
            prob=(0.025, 0.25),
            grad_scale=0.01)
-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
            channels, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1566,7 +1567,7 @@ class ConvolutionModule(nn.Module):
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
            bottleneck_dim, channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
            max_positive=1.0,
@@ -1590,7 +1591,7 @@
            bias=True,
        )
-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
            bottleneck_dim, channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
            max_positive=1.0,
@@ -1695,7 +1696,7 @@ class ConvNeXt(nn.Module):
            out_channels=hidden_channels,
            kernel_size=1)
-        self.hidden_balancer = ActivationBalancer(hidden_channels,
+        self.hidden_balancer = Balancer(hidden_channels,
                                        channel_dim=1,
                                        min_positive=0.3,
                                        max_positive=1.0,
@@ -1709,7 +1710,7 @@
            kernel_size=1,
            initial_scale=0.01)
-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
            channels, channel_dim=1,
            min_positive=0.4, max_positive=0.6,
            min_abs=1.0, max_abs=6.0,
@@ -1800,7 +1801,7 @@ class Conv2dSubsampling(nn.Module):
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
-            ActivationBalancer(layer1_channels,
+            Balancer(layer1_channels,
                     channel_dim=1,
                     max_abs=1.0),
            SwooshR(),
@@ -1811,7 +1812,7 @@
                stride=2,
                padding=0,
            ),
-            ActivationBalancer(layer2_channels,
+            Balancer(layer2_channels,
                     channel_dim=1,
                     max_abs=4.0),
            SwooshR(),
@@ -1833,7 +1834,7 @@
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
-            ActivationBalancer(layer3_channels,
+            Balancer(layer3_channels,
                     channel_dim=1,
                     max_abs=4.0),
            SwooshR(),
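All of the call sites in this file follow the same pattern: only the constructor name changes, and the keyword arguments used here (min_positive, max_positive, min_abs, max_abs, grad_scale, prob) keep their meaning. A minimal sketch of the placement pattern, with made-up dimensions and nn.ReLU standing in for the SwooshR activation used in the diff, assuming scaling.py is importable:

import torch
import torch.nn as nn
from scaling import Balancer   # assumes scaling.py is on the Python path

# hypothetical block mirroring the Conv2dSubsampling pattern above: a conv layer,
# then a Balancer constraining per-channel magnitude (channel_dim=1 for NCHW),
# then the activation.
layer = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2),
    Balancer(8, channel_dim=1, max_abs=1.0),
    nn.ReLU(),
)
x = torch.randn(2, 1, 32, 32)
print(layer(x).shape)   # torch.Size([2, 8, 15, 15])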