Replace ActivationBalancer with Balancer

Commit 59be36181c (parent c6bad1ee4f) in https://github.com/k2-fsa/icefall.git
@@ -23,6 +23,7 @@ import logging
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@@ -30,7 +31,6 @@ from torch.nn import Embedding as ScaledEmbedding


class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -975,6 +975,179 @@ class ActivationBalancer(torch.nn.Module):
            return _no_op(x)


class BalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min_mean: float,
        max_mean: float,
        min_rms: float,
        max_rms: float,
        grad_scale: float,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        ctx.save_for_backward(x)
        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None]:
        x, = ctx.saved_tensors
        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x = x.to(torch.float32)
                x = x.detach()
                x.requires_grad = True
                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
                uncentered_var = (x ** 2).mean(dim=mean_dims)
                mean = x.mean(dim=mean_dims)
                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
                rms = uncentered_var.clamp(min=1.0e-20).sqrt()

                m = mean / stddev
                # part of loss that relates to mean / stddev
                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()

                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
                r_loss = (rms_clamped / rms).log().abs()

                loss = (m_loss + r_loss).sum()

                loss.backward()
                loss_grad = x.grad
                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)

                loss_grad = loss_grad * (grad_scale / loss_grad_rms)

                x_grad_float = x_grad.to(torch.float32)
                # scale each element of loss_grad by the absolute value of the corresponding
                # element of x_grad, which we view as a noisy estimate of its magnitude for that
                # (frame and dimension). later we can consider factored versions.
                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                x_grad = x_grad_mod.to(x_grad.dtype)

        return x_grad, None, None, None, None, None, None
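
Not part of the diff: a self-contained sketch (editor's illustration, with made-up limits rather than the commit's defaults) of the arithmetic that BalancerFunction.backward performs. The real code runs this lazily inside backward() under torch.enable_grad(); here it is written out eagerly on a toy tensor so the size of the gradient nudge is visible.

import torch

torch.manual_seed(0)
min_mean, max_mean = -0.1, 0.1      # toy limits, not the commit's defaults
min_rms, max_rms = 0.25, 2.5
grad_scale = 0.02

x = (5.0 * torch.randn(100, 4)).detach().requires_grad_(True)  # channel RMS ~5 > max_rms
x_grad = torch.randn(100, 4)        # stand-in for the incoming gradient

mean_dims = [0]                     # every dim except the channel dim (here dim 1)
uncentered_var = (x ** 2).mean(dim=mean_dims)
mean = x.mean(dim=mean_dims)
stddev = (uncentered_var - mean * mean).clamp(min=1.0e-20).sqrt()
rms = uncentered_var.clamp(min=1.0e-20).sqrt()

m = mean / stddev
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()          # mean/stddev out of range
r_loss = (rms.clamp(min=min_rms, max=max_rms) / rms).log().abs()  # RMS out of range
(m_loss + r_loss).sum().backward()

loss_grad = x.grad
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
loss_grad = loss_grad * (grad_scale / loss_grad_rms)              # RMS-normalize to grad_scale

x_grad_mod = x_grad + x_grad.abs() * loss_grad                    # elementwise nudge
print((x_grad_mod - x_grad).abs().mean() / x_grad.abs().mean())   # on the order of grad_scale

The point to notice is that the extra term is scaled elementwise by |x_grad| and RMS-normalized to grad_scale, so the balancing signal stays a small, roughly fixed fraction of the gradient that is already flowing.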


class Balancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       scale_gain_factor: determines the 'gain' with which we increase the
           change in gradient once the constraints on min_abs and max_abs
           are violated.
       min_abs: the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward().  This is done randomly to prevent all layers
           from doing it at the same time.
    """
    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: FloatLike = 0.05,
        max_positive: FloatLike = 0.95,
        min_abs: FloatLike = 0.2,
        max_abs: FloatLike = 100.0,
        grad_scale: FloatLike = 0.02,
        prob: Optional[FloatLike] = None,
    ):
        super().__init__()

        if prob is None:
            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
        # 5% of the time we will return and do nothing because memory usage is
        # too high.
        self.mem_cutoff = CutoffEstimator(0.05)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        if (torch.jit.is_scripting() or not x.requires_grad or
            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
            return _no_op(x)

        prob = float(self.prob)
        if random.random() < prob:
            # The following inner-functions convert from the way we historically specified
            # these limitations, as limits on the absolute value and the proportion of positive
            # values, to limits on the RMS value and the (mean / stddev).
            def _abs_to_rms(x):
                # for normally distributed data, if the expected absolute value is x, the
                # expected rms value will be sqrt(pi/2) * x.
                return 1.25331413732 * x

            def _proportion_positive_to_mean(x):
                def _atanh(x):
                    eps = 1.0e-10
                    # eps is to prevent crashes if x is exactly 0 or 1.
                    # we'll just end up returning a fairly large value.
                    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.

                def _approx_inverse_erf(x):
                    # 1 / (sqrt(pi) * ln(2)),
                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
                    # this approximation is extremely crude and gets progressively worse for
                    # x very close to -1 or +1, but we mostly care about the "middle" region
                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
                    # which is pretty close to 0.05.
                    return 0.8139535143 * _atanh(x)

                # first convert x from the range 0..1 to the range -1..1 which the error
                # function returns
                x = -1 + (2 * x)
                return _approx_inverse_erf(x)

            min_mean = _proportion_positive_to_mean(float(self.min_positive))
            max_mean = _proportion_positive_to_mean(float(self.max_positive))
            min_rms = _abs_to_rms(float(self.min_abs))
            max_rms = _abs_to_rms(float(self.max_abs))
            grad_scale = float(self.grad_scale)

            assert x.shape[self.channel_dim] == self.num_channels

            return BalancerFunction.apply(
                x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
            )
        else:
            return _no_op(x)


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
                           name: str = None) -> Tensor:
    """
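
A quick numeric check (editor's sketch, standard library only) of the constants quoted in the comments of the conversion helpers above: the _abs_to_rms factor is sqrt(pi/2), and _approx_inverse_erf is atanh(x) / (sqrt(pi) * ln 2).

import math

def _atanh(x, eps=1.0e-10):
    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0

def _approx_inverse_erf(x):
    # erf(y) ~= tanh(sqrt(pi) * ln(2) * y), hence erf^-1(x) ~= atanh(x) / (sqrt(pi) * ln(2))
    return 0.8139535143 * _atanh(x)

print(math.sqrt(math.pi / 2))          # 1.2533141373155003, the _abs_to_rms factor
print(_approx_inverse_erf(0.05))       # 0.04073164..., as quoted in the comment
print(math.erf(0.0407316414078772))    # 0.04593533..., reasonably close to 0.05
# An erf argument of 0.05 corresponds to a proportion-positive of (1 + 0.05) / 2 = 0.525.
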
@@ -1731,6 +1904,7 @@ def _test_activation_balancer_sign():
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -1742,6 +1916,31 @@ def _test_activation_balancer_sign():
    print("_test_activation_balancer_sign: x grad = ", x.grad)


def _test_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        min_abs=0.0,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_sign: x = ", x)
    print("_test_balancer_sign: y grad = ", y_grad)
    print("_test_balancer_sign: x grad = ", x.grad)


def _test_activation_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
@@ -1770,6 +1969,34 @@ def _test_activation_balancer_magnitude():
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)


def _test_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
        -1
    )
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.2,
        max_abs=0.7,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_magnitude: x = ", x)
    print("_test_balancer_magnitude: y grad = ", y_grad)
    print("_test_balancer_magnitude: x grad = ", x.grad)


def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)
@@ -1862,7 +2089,9 @@ if __name__ == "__main__":
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
    _test_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
    _test_tan_swish_deriv()
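
Not part of the diff: a minimal usage sketch of the new Balancer, in the spirit of the _test_balancer_* functions above. It assumes the scaling.py from this commit is importable; the forward pass is an identity, and only the backward pass is modified.

import torch
from scaling import Balancer   # the class added in this commit

m = Balancer(num_channels=4, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
             min_abs=0.2, max_abs=2.0, prob=1.0)

# Per-channel RMS of ~5 is well above max_abs=2.0, so the constraint is active.
x = (5.0 * torch.randn(100, 4)).detach().requires_grad_(True)
y = m(x)
print(torch.equal(y, x))                 # True: forward leaves the values untouched
y.backward(gradient=torch.ones_like(y))
# The all-ones incoming gradient picks up an extra balancing term whose relative
# size is on the order of grad_scale (0.02 by default).
print((x.grad - 1.0).abs().mean())
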
@@ -26,6 +26,7 @@ import random
from encoder_interface import EncoderInterface
from scaling import (
    ActivationBalancer,
    Balancer,
    BasicNorm,
    ConvNorm1d,
    ConvNorm2d,
@@ -456,7 +457,7 @@ class ZipformerEncoderLayer(nn.Module):
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(
+        self.balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.45, max_positive=0.55,
            min_abs=1.0, max_abs=6.0,
@@ -1302,7 +1303,7 @@ class AttentionSqueeze(nn.Module):
        # instances of this module, the mean absolute values per channel are in
        # the range 0.1 to 0.4. We apply the upper limit of 0.4 at the
        # beginning, and make it looser over time.
-        self.bottleneck_balancer = ActivationBalancer(
+        self.bottleneck_balancer = Balancer(
            bottleneck_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.05,
@@ -1316,13 +1317,13 @@ class AttentionSqueeze(nn.Module):
        # many degrees of freedom for the scales of the various activations.
        # Make them run with very low probability, since only a small
        # application of these balancers should be enough to stop such "drift".
-        self.scale_balancer = ActivationBalancer(
+        self.scale_balancer = Balancer(
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
            prob=_balancer_schedule(0.05),
        )
-        self.activation_balancer = ActivationBalancer(
+        self.activation_balancer = Balancer(
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
@@ -1339,7 +1340,7 @@ class AttentionSqueeze(nn.Module):
        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                     bias=False, initial_scale=0.05)

-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1396,7 +1397,7 @@ class FeedforwardModule(nn.Module):
        super(FeedforwardModule, self).__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)

-        self.hidden_balancer = ActivationBalancer(feedforward_dim,
+        self.hidden_balancer = Balancer(feedforward_dim,
                                        channel_dim=-1,
                                        min_positive=0.3,
                                        max_positive=1.0,
@@ -1447,7 +1448,7 @@ class NonlinAttentionModule(nn.Module):
        # because we noticed that well-trained instances of this module have abs-value before the sigmoid
        # starting from about 3, and poorly-trained instances of the module have smaller abs values
        # before the sigmoid.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
            hidden_channels, channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
@@ -1473,7 +1474,7 @@ class NonlinAttentionModule(nn.Module):
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
            channels, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1566,7 +1567,7 @@ class ConvolutionModule(nn.Module):
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
            bottleneck_dim, channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
            max_positive=1.0,
@@ -1590,7 +1591,7 @@ class ConvolutionModule(nn.Module):
            bias=True,
        )

-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
            bottleneck_dim, channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
            max_positive=1.0,
@@ -1695,7 +1696,7 @@ class ConvNeXt(nn.Module):
                               out_channels=hidden_channels,
                               kernel_size=1)

-        self.hidden_balancer = ActivationBalancer(hidden_channels,
+        self.hidden_balancer = Balancer(hidden_channels,
                                        channel_dim=1,
                                        min_positive=0.3,
                                        max_positive=1.0,
@@ -1709,7 +1710,7 @@ class ConvNeXt(nn.Module):
                              kernel_size=1,
                              initial_scale=0.01)

-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
            channels, channel_dim=1,
            min_positive=0.4, max_positive=0.6,
            min_abs=1.0, max_abs=6.0,
@@ -1800,7 +1801,7 @@ class Conv2dSubsampling(nn.Module):
                padding=(0, 1), # (time, freq)
            ),
            ScaleGrad(0.2),
-            ActivationBalancer(layer1_channels,
+            Balancer(layer1_channels,
                     channel_dim=1,
                     max_abs=1.0),
            SwooshR(),
@@ -1811,7 +1812,7 @@ class Conv2dSubsampling(nn.Module):
                stride=2,
                padding=0,
            ),
-            ActivationBalancer(layer2_channels,
+            Balancer(layer2_channels,
                     channel_dim=1,
                     max_abs=4.0),
            SwooshR(),
@@ -1833,7 +1834,7 @@ class Conv2dSubsampling(nn.Module):
                kernel_size=3,
                stride=(1, 2), # (time, freq)
            ),
-            ActivationBalancer(layer3_channels,
+            Balancer(layer3_channels,
                     channel_dim=1,
                     max_abs=4.0),
            SwooshR(),
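
Across the model code the change is otherwise mechanical: each ActivationBalancer(...) construction becomes Balancer(...) with the same keyword arguments. A hedged sketch of the wiring pattern follows (TinyLayer is a made-up name, prob is left at its default, the keyword values are the ones from the ZipformerEncoderLayer hunk above, and scaling.py from this commit is assumed to be importable):

import torch
import torch.nn as nn
from scaling import Balancer

class TinyLayer(nn.Module):
    def __init__(self, embed_dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(embed_dim, embed_dim)
        # previously: self.balancer = ActivationBalancer(...)
        self.balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.45, max_positive=0.55,
            min_abs=1.0, max_abs=6.0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # identity in the forward direction; only the gradients flowing back
        # into self.linear are adjusted toward the stated constraints
        return self.balancer(self.linear(x))

layer = TinyLayer()
out = layer(torch.randn(10, 8))
out.sum().backward()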