Mirror of https://github.com/k2-fsa/icefall.git
Merge branch 'scaled_adam_exp795' into scaled_adam_exp798
This commit is contained in: commit 0c3530a6fd
@@ -23,6 +23,7 @@ import logging
 from torch.cuda.amp import custom_fwd, custom_bwd
 import random
 import torch
+import math
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
@@ -30,7 +31,6 @@ from torch.nn import Embedding as ScaledEmbedding
 
 
 
-
 class ScheduledFloat(torch.nn.Module):
     """
     This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -975,6 +975,179 @@ class ActivationBalancer(torch.nn.Module):
             return _no_op(x)
 
 
+class BalancerFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        min_mean: float,
+        max_mean: float,
+        min_rms: float,
+        max_rms: float,
+        grad_scale: float,
+        channel_dim: int,
+    ) -> Tensor:
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        ctx.save_for_backward(x)
+        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
+        return x
+
+    @staticmethod
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None, None]:
+        x, = ctx.saved_tensors
+        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = x.to(torch.float32)
+                x = x.detach()
+                x.requires_grad = True
+                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
+                uncentered_var = (x ** 2).mean(dim=mean_dims)
+                mean = x.mean(dim=mean_dims)
+                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+                m = mean / stddev
+                # part of loss that relates to mean / stddev
+                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+                r_loss = (rms_clamped / rms).log().abs()
+
+                loss = (m_loss + r_loss).sum()
+
+                loss.backward()
+                loss_grad = x.grad
+                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+
+                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+                x_grad_float = x_grad.to(torch.float32)
+                # scale each element of loss_grad by the absolute value of the corresponding
+                # element of x_grad, which we view as a noisy estimate of its magnitude for that
+                # (frame and dimension).  later we can consider factored versions.
+                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                x_grad = x_grad_mod.to(x_grad.dtype)
+
+        return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to encourage, for
+    each channel, that it is positive at least a proportion `threshold` of the
+    time.  It does this by multiplying negative derivative values by up to
+    (1+max_factor), and positive derivative values by up to (1-max_factor),
+    interpolated from 1 at the threshold to those extremal values when none
+    of the inputs are positive.
+
+    Args:
+       num_channels: the number of channels
+       channel_dim: the dimension/axis corresponding to the channel, e.g.
+           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+       min_positive: the minimum, per channel, of the proportion of the time
+           that (x > 0), below which we start to modify the derivatives.
+       max_positive: the maximum, per channel, of the proportion of the time
+           that (x > 0), above which we start to modify the derivatives.
+       scale_gain_factor: determines the 'gain' with which we increase the
+           change in gradient once the constraints on min_abs and max_abs
+           are violated.
+       min_abs: the minimum average-absolute-value difference from the mean
+           value per channel, which we allow, before we start to modify
+           the derivatives to prevent this.
+       max_abs: the maximum average-absolute-value difference from the mean
+           value per channel, which we allow, before we start to modify
+           the derivatives to prevent this.
+       prob: determines the minimum probability with which we modify the
+           gradients for the {min,max}_positive and {min,max}_abs constraints,
+           on each forward().  This is done randomly to prevent all layers
+           from doing it at the same time.
+    """
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        grad_scale: FloatLike = 0.04,
+        prob: Optional[FloatLike] = None,
+    ):
+        super().__init__()
+
+        if prob is None:
+            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+        self.prob = prob
+        # 5% of the time we will return and do nothing because memory usage is
+        # too high.
+        self.mem_cutoff = CutoffEstimator(0.05)
+
+        # actually self.num_channels is no longer needed except for an assertion.
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.min_positive = min_positive
+        self.max_positive = max_positive
+        self.min_abs = min_abs
+        self.max_abs = max_abs
+        self.grad_scale = grad_scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
+            return _no_op(x)
+
+        prob = float(self.prob)
+        if random.random() < prob:
+            # The following inner-functions convert from the way we historically specified
+            # these limitations, as limits on the absolute value and the proportion of positive
+            # values, to limits on the RMS value and the (mean / stddev).
+            def _abs_to_rms(x):
+                # for normally distributed data, if the expected absolute value is x, the
+                # expected rms value will be sqrt(pi/2) * x.
+                return 1.25331413732 * x
+            def _proportion_positive_to_mean(x):
+                def _atanh(x):
+                    eps = 1.0e-10
+                    # eps is to prevent crashes if x is exactly 0 or 1.
+                    # we'll just end up returning a fairly large value.
+                    return (math.log (1+x+eps) - math.log (1-x+eps)) / 2.
+                def _approx_inverse_erf(x):
+                    # 1 / (sqrt(pi) * ln(2)),
+                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
+                    # this approximation is extremely crude and gets progressively worse for
+                    # x very close to -1 or +1, but we mostly care about the "middle" region
+                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
+                    # which is pretty close to 0.05.
+                    return 0.8139535143 * _atanh(x)
+                # first convert x from the range 0..1 to the range -1..1 which the error
+                # function returns
+                x = -1 + (2 * x)
+                return _approx_inverse_erf(x)
+
+            min_mean = _proportion_positive_to_mean(float(self.min_positive))
+            max_mean = _proportion_positive_to_mean(float(self.max_positive))
+            min_rms = _abs_to_rms(float(self.min_abs))
+            max_rms = _abs_to_rms(float(self.max_abs))
+            grad_scale = float(self.grad_scale)
+
+            assert x.shape[self.channel_dim] == self.num_channels
+
+            return BalancerFunction.apply(
+                x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
+            )
+        else:
+            return _no_op(x)
+
+
 def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
                            name: str = None) -> Tensor:
     """
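The two inner helpers in Balancer.forward above translate the legacy (min_positive, max_positive, min_abs, max_abs) settings into limits on the per-channel mean/stddev and RMS that BalancerFunction actually enforces. A standalone sketch of that mapping (the function names here are illustrative re-implementations, not part of the commit; the constants and the atanh-based inverse-erf approximation are copied from the hunk above):

    import math

    def abs_to_rms(a: float) -> float:
        # For zero-mean, normally distributed data E[|x|] = sqrt(2/pi) * rms,
        # so rms = sqrt(pi/2) * E[|x|] ~= 1.2533 * E[|x|].
        return 1.25331413732 * a

    def proportion_positive_to_mean(p: float) -> float:
        # Same crude approximation as in the diff: erf^-1(y) ~= 0.8139535143 * atanh(y).
        def atanh(y: float) -> float:
            eps = 1.0e-10  # avoids log(0) when p is exactly 0 or 1
            return (math.log(1 + y + eps) - math.log(1 - y + eps)) / 2.0
        return 0.8139535143 * atanh(2 * p - 1)

    # The default limits min_positive=0.05, max_positive=0.95, min_abs=0.2 become:
    print(proportion_positive_to_mean(0.05))  # about -1.20 (lower limit on mean/stddev)
    print(proportion_positive_to_mean(0.95))  # about +1.20 (upper limit on mean/stddev)
    print(abs_to_rms(0.2))                    # about 0.25  (lower limit on RMS)

This is why BalancerFunction.backward only ever looks at the per-channel mean and uncentered variance of x: all four user-facing limits have already been reduced to min_mean/max_mean and min_rms/max_rms.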
@@ -1731,6 +1904,7 @@ def _test_activation_balancer_sign():
         max_positive=0.95,
         max_factor=0.2,
         min_abs=0.0,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -1742,6 +1916,31 @@ def _test_activation_balancer_sign():
     print("_test_activation_balancer_sign: x grad = ", x.grad)
 
 
+def _test_balancer_sign():
+    probs = torch.arange(0, 1, 0.01)
+    N = 1000
+    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+    x = x.detach()
+    x.requires_grad = True
+    m = Balancer(
+        probs.numel(),
+        channel_dim=0,
+        min_positive=0.05,
+        max_positive=0.95,
+        min_abs=0.0,
+        prob=1.0,
+    )
+
+    y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_balancer_sign: x = ", x)
+    print("_test_balancer_sign: y grad = ", y_grad)
+    print("_test_balancer_sign: x grad = ", x.grad)
+
+
 def _test_activation_balancer_magnitude():
     magnitudes = torch.arange(0, 1, 0.01)
     N = 1000
@@ -1770,6 +1969,34 @@ def _test_activation_balancer_magnitude():
     print("_test_activation_balancer_magnitude: x grad = ", x.grad)
 
 
+def _test_balancer_magnitude():
+    magnitudes = torch.arange(0, 1, 0.01)
+    N = 1000
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+        -1
+    )
+    x = x.detach()
+    x.requires_grad = True
+    m = Balancer(
+        magnitudes.numel(),
+        channel_dim=0,
+        min_positive=0.0,
+        max_positive=1.0,
+        min_abs=0.2,
+        max_abs=0.7,
+        prob=1.0,
+    )
+
+    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_balancer_magnitude: x = ", x)
+    print("_test_balancer_magnitude: y grad = ", y_grad)
+    print("_test_balancer_magnitude: x grad = ", x.grad)
+
+
 def _test_basic_norm():
     num_channels = 128
     m = BasicNorm(num_channels=num_channels, channel_dim=1)
@@ -1862,7 +2089,9 @@ if __name__ == "__main__":
     _test_whiten()
     _test_max_eig()
     _test_activation_balancer_sign()
+    _test_balancer_sign()
     _test_activation_balancer_magnitude()
+    _test_balancer_magnitude()
     _test_basic_norm()
     _test_double_swish_deriv()
     _test_tan_swish_deriv()
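Before the Zipformer model changes below, a minimal usage sketch of the new Balancer (a sketch only: it assumes this commit's scaling.py is importable as `scaling`; the shapes and prob=1.0 are chosen purely for illustration):

    import torch
    from scaling import Balancer  # assumption: scaling.py from this commit is on the path

    balancer = Balancer(num_channels=384, channel_dim=-1,
                        min_positive=0.05, max_positive=0.95,
                        min_abs=0.2, max_abs=100.0,
                        prob=1.0)  # prob=1.0 so the gradient constraint fires on every call

    x = torch.randn(10, 20, 384, requires_grad=True)
    y = balancer(x)      # the forward pass returns x unchanged
    (y ** 2).mean().backward()
    print(x.grad.shape)  # torch.Size([10, 20, 384]); the gradient has been nudged so that
                         # training pushes each channel's mean/RMS toward the configured range

In training, prob is normally left at its scheduled default so the (fairly expensive) gradient modification only runs on a fraction of forward calls, as described in the docstring above.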
@@ -26,6 +26,7 @@ import random
 from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
+    Balancer,
     BasicNorm,
     ConvNorm1d,
     ConvNorm2d,
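The remaining hunks in this file are a mechanical rename: every ActivationBalancer construction keeps its keyword arguments and only the class name changes to Balancer. As a quick drop-in sanity sketch of the channel_dim=1 style used in Conv2dSubsampling further down (hypothetical shapes; again assuming scaling.py from this commit is importable):

    import torch
    from scaling import Balancer  # assumption: scaling.py from this commit is on the path

    layer1_channels = 8
    balancer = Balancer(layer1_channels, channel_dim=1, max_abs=1.0)  # as in the Conv2dSubsampling hunk

    feats = torch.randn(2, layer1_channels, 50, 40, requires_grad=True)  # (batch, channels, time, freq)
    out = balancer(feats)    # identity in the forward pass
    out.sum().backward()     # the per-channel (dim=1) constraint is applied with probability
                             # float(self.prob), so feats.grad may or may not be modified here
    print(feats.grad.shape)  # torch.Size([2, 8, 50, 40])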
@@ -456,7 +457,7 @@ class ZipformerEncoderLayer(nn.Module):
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(
+        self.balancer = Balancer(
             embed_dim, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             min_abs=1.0, max_abs=6.0,
@@ -1302,7 +1303,7 @@ class AttentionSqueeze(nn.Module):
         # instances of this module, the mean absolute values per channel are in
         # the range 0.1 to 0.4.  We apply the upper limit of 0.4 at the
         # beginning, and make it looser over time.
-        self.bottleneck_balancer = ActivationBalancer(
+        self.bottleneck_balancer = Balancer(
             bottleneck_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.05,
@@ -1316,13 +1317,13 @@ class AttentionSqueeze(nn.Module):
         # many degrees of freedom for the scales of the various activations.
         # Make them run with very low probability, since only a small
         # application of these balancers should be enough to stop such "drift".
-        self.scale_balancer = ActivationBalancer(
+        self.scale_balancer = Balancer(
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             prob=_balancer_schedule(0.05),
         )
-        self.activation_balancer = ActivationBalancer(
+        self.activation_balancer = Balancer(
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
@@ -1339,7 +1340,7 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
 
-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
             embed_dim, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
             min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1396,7 +1397,7 @@ class FeedforwardModule(nn.Module):
         super(FeedforwardModule, self).__init__()
         self.in_proj = nn.Linear(embed_dim, feedforward_dim)
 
-        self.hidden_balancer = ActivationBalancer(feedforward_dim,
+        self.hidden_balancer = Balancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=0.3,
                                                   max_positive=1.0,
@@ -1447,7 +1448,7 @@ class NonlinAttentionModule(nn.Module):
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
         # starting from about 3, and poorly-trained instances of the module have smaller abs values
         # before the sigmoid.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
             hidden_channels, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
             max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
@@ -1473,7 +1474,7 @@ class NonlinAttentionModule(nn.Module):
                                   prob=(0.025, 0.25),
                                   grad_scale=0.01)
 
-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
             channels, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
             min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
@@ -1566,7 +1567,7 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.balancer1 = ActivationBalancer(
+        self.balancer1 = Balancer(
             bottleneck_dim, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
             max_positive=1.0,
@@ -1590,7 +1591,7 @@ class ConvolutionModule(nn.Module):
             bias=True,
         )
 
-        self.balancer2 = ActivationBalancer(
+        self.balancer2 = Balancer(
             bottleneck_dim, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
@@ -1695,7 +1696,7 @@ class ConvNeXt(nn.Module):
                                     out_channels=hidden_channels,
                                     kernel_size=1)
 
-        self.hidden_balancer = ActivationBalancer(hidden_channels,
+        self.hidden_balancer = Balancer(hidden_channels,
                                                   channel_dim=1,
                                                   min_positive=0.3,
                                                   max_positive=1.0,
@@ -1709,7 +1710,7 @@ class ConvNeXt(nn.Module):
                                  kernel_size=1,
                                  initial_scale=0.01)
 
-        self.out_balancer = ActivationBalancer(
+        self.out_balancer = Balancer(
             channels, channel_dim=1,
             min_positive=0.4, max_positive=0.6,
             min_abs=1.0, max_abs=6.0,
@@ -1800,7 +1801,7 @@ class Conv2dSubsampling(nn.Module):
                 padding=(0, 1),  # (time, freq)
             ),
             ScaleGrad(0.2),
-            ActivationBalancer(layer1_channels,
+            Balancer(layer1_channels,
                                channel_dim=1,
                                max_abs=1.0),
             SwooshR(),
@@ -1811,7 +1812,7 @@ class Conv2dSubsampling(nn.Module):
                 stride=2,
                 padding=0,
             ),
-            ActivationBalancer(layer2_channels,
+            Balancer(layer2_channels,
                                channel_dim=1,
                                max_abs=4.0),
             SwooshR(),
@@ -1833,7 +1834,7 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=(1, 2),  # (time, freq)
             ),
-            ActivationBalancer(layer3_channels,
+            Balancer(layer3_channels,
                                channel_dim=1,
                                max_abs=4.0),
             SwooshR(),