Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Commit e838c192ef (parent dfc75752c4): Cosmetic changes/renaming things
@@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module):
             ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
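The input size of the final ScaledLinear follows from the two stride-2 convolutions applied along the frequency axis before flattening. A minimal sketch of that arithmetic (illustrative only; idim = 80 is an assumed example value, not something stated in this diff):

# A kernel_size=3, stride=2, padding=0 convolution maps a length L to
# floor((L - 3) / 2) + 1, which equals (L - 1) // 2.
def subsampled_freq_dim(idim: int) -> int:
    after_first = (idim - 1) // 2          # after the first 3x3, stride-2 conv
    after_second = (after_first - 1) // 2  # after the second 3x3, stride-2 conv
    return after_second

# e.g. 80 input features -> 19, so ScaledLinear sees odim * 19 flattened features.
assert subsampled_freq_dim(80) == 19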
@@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module):
         # needed.
         self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)


     def forward(self, x: torch.Tensor) -> torch.Tensor:
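The min_positive=0.45 / max_positive=0.55 window is what the "constrain median of output to be close to zero" comment means: a channel's median is positive precisely when more than half of its values are, so pinning the proportion of positive values near 0.5 pins the median near zero. A standalone illustration (not part of the commit):

import torch

x = torch.randn(100_000) + 0.3            # a channel with a positive bias
prop_positive = (x > 0).float().mean()    # roughly 0.62, outside [0.45, 0.55]
print(prop_positive.item(), x.median().item())
# Forcing prop_positive back into [0.45, 0.55] necessarily drags the median toward zero.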
@@ -177,7 +179,7 @@ class VggSubsampling(nn.Module):



-class DerivBalancerFunction(torch.autograd.Function):
+class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
@@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d):



-class DerivBalancer(torch.nn.Module):
+class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
     each channel, that it is positive at least a proportion `threshold` of the
     time.  It does this by multiplying negative derivative values by up to
     (1+max_factor), and positive derivative values by up to (1-max_factor),
-    interpolated from 0 at the threshold to those extremal values when none
+    interpolated from 1 at the threshold to those extremal values when none
     of the inputs are positive.

-    When all grads are zero for a channel, this
-    module sets all the input derivatives for that channel to -epsilon; the
-    idea is to bring completely dead neurons back to life this way.

     Args:
-           channel_dim: the dimension/axi corresponding to the channel, e.g.
+           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
              that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
              that (x > 0), below which we start to modify the derivatives.
-          max_factor: the maximum factor by which we modify the derivatives,
+          max_factor: the maximum factor by which we modify the derivatives for
+             either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the the derivatives would be multiplied by
-             values in the range [0.98..1.01].
-          zero: we use this value in the comparison (x > 0), i.e. we actually use
-             (x > zero).  The reason for using a threshold slightly greater
-             than zero is that it will tend to prevent situations where the
-             inputs shrink close to zero and the nonlinearity (e.g. swish)
-             behaves like a linear function and we learn nothing.
+             values in the range [0.98..1.02].
           min_abs: the minimum average-absolute-value per channel, which
              we allow, before we start to modify the derivatives to prevent
-             this.  This is to prevent a failure mode where the activations
-             become so small that the nonlinearity effectively becomes linear,
-             which makes the module useless and it gets even smaller
-             to try to "turn it off" completely.
+             this.
           max_abs: the maximum average-absolute-value per channel, which
              we allow, before we start to modify the derivatives to prevent
-             this.  This is to prevent the possibility of activations getting
-             out of floating point numerical range (especially in half precision).
+             this.
     """
     def __init__(self, channel_dim: int,
                  min_positive: float = 0.05,
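A simplified sketch of the mechanism this docstring describes. It is not the repository's ActivationBalancerFunction: the class name is invented, the min_abs/max_abs magnitude constraint is omitted, and it assumes 0 < min_positive and max_positive < 1. The forward pass is the identity; the backward pass multiplies negative gradients by up to (1 + max_factor) and positive gradients by up to (1 - max_factor), per channel.

import torch
from torch import Tensor


class BalancerSketchFunction(torch.autograd.Function):
    """Identity in the forward pass; rescales gradients per channel in backward."""

    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int,
                min_positive: float, max_positive: float,
                max_factor: float) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        other_dims = [d for d in range(x.ndim) if d != channel_dim]
        # Per-channel proportion of the time that x > 0.
        prop_pos = (x > 0).to(x.dtype).mean(dim=other_dims, keepdim=True)
        # The scaling factor grows linearly from 0 at the threshold to max_factor
        # at the extreme: positive where too few values are positive, negative
        # where too many are.
        too_few = ((min_positive - prop_pos) / min_positive).clamp(min=0.0, max=1.0)
        too_many = ((prop_pos - max_positive) / (1.0 - max_positive)).clamp(min=0.0, max=1.0)
        ctx.save_for_backward((too_few - too_many) * max_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (factor,) = ctx.saved_tensors
        # grad - factor * |grad| multiplies negative grads by (1 + factor) and
        # positive grads by (1 - factor), nudging the channel toward the target
        # proportion of positive activations.
        return x_grad - factor * x_grad.abs(), None, None, None, None


# Example call: y = BalancerSketchFunction.apply(x, -1, 0.05, 0.95, 0.01)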
@@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module):
                  max_factor: float = 0.01,
                  min_abs: float = 0.2,
                  max_abs: float = 100.0):
-        super(DerivBalancer, self).__init__()
+        super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
@@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module):
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        return DerivBalancerFunction.apply(x, self.channel_dim,
-                                           self.min_positive, self.max_positive,
-                                           self.max_factor, self.min_abs,
-                                           self.max_abs)
+        return ActivationBalancerFunction.apply(x, self.channel_dim,
+                                                self.min_positive, self.max_positive,
+                                                self.max_factor, self.min_abs,
+                                                self.max_abs)


 def _double_swish(x: Tensor) -> Tensor:
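Because the forward pass leaves the values unchanged (only the backpropped derivatives are modified), the module can be dropped into nn.Sequential stacks without affecting what the network computes. A small usage illustration, assuming the classes defined in this file and an arbitrary (time, batch, feature) shape:

import torch

balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
x = torch.randn(20, 8, 256, requires_grad=True)   # assumed (time, batch, feature) shape
y = balancer(x)
assert torch.equal(y, x)   # values pass through unchanged in the forward direction
y.sum().backward()         # per-channel gradient rescaling happens here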
@@ -524,8 +515,8 @@ def _test_deriv_balancer_sign():
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                      max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
+                           max_factor=0.2, min_abs=0.0)

     y_grad = torch.sign(torch.randn(probs.numel(), N))

@@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude():
     x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0,
-                      min_positive=0.0, max_positive=1.0,
-                      max_factor=0.2,
-                      min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(channel_dim=0,
+                           min_positive=0.0, max_positive=1.0,
+                           max_factor=0.2,
+                           min_abs=0.2, max_abs=0.8)

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

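In miniature, what the sign test above exercises (an illustration based on the docstring's description, using only constructor arguments that appear in this diff; the shapes are arbitrary): a channel that is never positive has its gradients rescaled relative to a balanced channel.

import torch

m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
                       max_factor=0.2, min_abs=0.0)
x = torch.empty(2, 1000)
x[0] = -1.0                # channel 0: proportion positive = 0, below min_positive
x[1] = torch.randn(1000)   # channel 1: roughly half positive, inside [0.05, 0.95]
x.requires_grad_(True)
m(x).backward(gradient=torch.ones_like(x))
print(x.grad[0].abs().mean(), x.grad[1].abs().mean())   # channel 0 is scaled differently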
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d

 import torch
 from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module):
         self.norm_final = BasicNorm(d_model)

         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.balancer = ActivationBalancer(channel_dim=-1,
+                                           min_positive=0.45,
+                                           max_positive=0.55)

         self.dropout = nn.Dropout(dropout)

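The same feed-forward branch, assembled on its own to make the tensor layout explicit (a sketch reusing names imported from subsampling above; d_model, dim_feedforward, dropout and the (time, batch, feature) input shape are assumed example values, not taken from this diff):

import torch
from torch import nn
from subsampling import ActivationBalancer, DoubleSwish, ScaledLinear

d_model, dim_feedforward, dropout = 256, 2048, 0.1   # assumed example values
feed_forward = nn.Sequential(
    ScaledLinear(d_model, dim_feedforward),
    ActivationBalancer(channel_dim=-1),   # statistics over the last (feature) dim
    DoubleSwish(),
    nn.Dropout(dropout),
    ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)

x = torch.randn(100, 8, d_model)      # assumed (time, batch, d_model) input
print(feed_forward(x).shape)          # expected: torch.Size([100, 8, 256])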
@@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )

-        self.deriv_balancer2 = DerivBalancer(channel_dim=1,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.activation = DoubleSwish()

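Note that deriv_balancer1 and deriv_balancer2 use channel_dim=1, unlike the channel_dim=-1 balancers around the linear layers: Conv1d/ScaledConv1d work on (batch, channels, time) tensors, so the channel axis is dimension 1 here. A small illustration using the same constructor arguments as above (the shape is an assumed example):

import torch
from subsampling import ActivationBalancer

balancer = ActivationBalancer(channel_dim=1, max_abs=10.0,
                              min_positive=0.05, max_positive=1.0)
x = torch.randn(8, 256, 100)   # assumed (batch, channels, time) layout for Conv1d
y = balancer(x)                # per-channel statistics are taken over dims 0 and 2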