Commit e838c192ef (parent dfc75752c4)

Cosmetic changes/renaming things: rename DerivBalancer to ActivationBalancer and DerivBalancerFunction to ActivationBalancerFunction, trim and correct the class docstring, and update the call sites in the subsampling and Conformer modules accordingly.
@@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module):
             ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
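A note on the unchanged context line self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim): that factor reflects the two kernel-3, stride-2, unpadded convolutions above it. A minimal sketch checking the arithmetic with plain nn.Conv2d stand-ins (the concrete idim/odim values and shapes are illustrative, not taken from the commit):

import torch
import torch.nn as nn

idim, odim = 80, 512  # hypothetical feature and output dimensions
convs = nn.Sequential(
    nn.Conv2d(1, odim, kernel_size=3, stride=2),     # stand-in for ScaledConv2d
    nn.Conv2d(odim, odim, kernel_size=3, stride=2),  # stand-in for ScaledConv2d
)
x = torch.randn(1, 1, 100, idim)  # (batch, channel, time, freq)
y = convs(x)

# Each kernel-3, stride-2, unpadded conv maps a length n to (n - 1) // 2,
# so the remaining frequency dimension is ((idim - 1) // 2 - 1) // 2.
assert y.shape[-1] == ((idim - 1) // 2 - 1) // 2
print(y.shape)  # torch.Size([1, 512, 24, 19]) for these illustrative values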
@@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module):
         # needed.
         self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)


     def forward(self, x: torch.Tensor) -> torch.Tensor:
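On the "constrain median of output to be close to zero" comment: min_positive=0.45 and max_positive=0.55 ask that each output channel be positive between 45% and 55% of the time, which is another way of keeping the per-channel median near zero. A toy illustration of that equivalence (the data here is made up):

import torch

x = torch.randn(100000) + 0.05  # one channel of nearly zero-median data
frac_positive = (x > 0).float().mean()
# If the fraction of positive values lies in [0.45, 0.55], then zero sits between the
# 45th and 55th percentiles of the data, so the median (50th percentile) is near zero.
print(float(frac_positive), float(x.median()))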
@@ -177,7 +179,7 @@ class VggSubsampling(nn.Module):



-class DerivBalancerFunction(torch.autograd.Function):
+class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
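The hunk above only renames the autograd Function. For background, the calling pattern it uses is the standard torch.autograd.Function one: extra non-tensor arguments are stashed on ctx in forward and receive None gradients in backward. A generic, self-contained sketch of that pattern (a toy Function, not the class's actual backward logic, which is not shown in this diff):

import torch
from torch import Tensor


class ScaleGradFunction(torch.autograd.Function):
    """Toy Function with the same calling pattern: identity forward, modified backward."""

    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int, factor: float) -> Tensor:
        ctx.channel_dim = channel_dim  # non-tensor args are stashed on ctx
        ctx.factor = factor
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # One return value per forward argument; the non-tensor args get None.
        return grad_output * ctx.factor, None, None


x = torch.randn(4, 3, requires_grad=True)
ScaleGradFunction.apply(x, 1, 0.5).sum().backward()
print(x.grad.unique())  # every gradient entry is scaled to 0.5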
@@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d):



-class DerivBalancer(torch.nn.Module):
+class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
     each channel, that it is positive at least a proportion `threshold` of the
     time. It does this by multiplying negative derivative values by up to
     (1+max_factor), and positive derivative values by up to (1-max_factor),
-    interpolated from 0 at the threshold to those extremal values when none
+    interpolated from 1 at the threshold to those extremal values when none
     of the inputs are positive.

-    When all grads are zero for a channel, this
-    module sets all the input derivatives for that channel to -epsilon; the
-    idea is to bring completely dead neurons back to life this way.

     Args:
-       channel_dim: the dimension/axi corresponding to the channel, e.g.
+       channel_dim: the dimension/axis corresponding to the channel, e.g.
          -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
          that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
          that (x > 0), below which we start to modify the derivatives.
-      max_factor: the maximum factor by which we modify the derivatives,
+      max_factor: the maximum factor by which we modify the derivatives for
+         either the sign constraint or the magnitude constraint;
         e.g. with max_factor=0.02, the the derivatives would be multiplied by
-         values in the range [0.98..1.01].
+         values in the range [0.98..1.02].
-      zero: we use this value in the comparison (x > 0), i.e. we actually use
-         (x > zero). The reason for using a threshold slightly greater
-         than zero is that it will tend to prevent situations where the
-         inputs shrink close to zero and the nonlinearity (e.g. swish)
-         behaves like a linear function and we learn nothing.
       min_abs: the minimum average-absolute-value per channel, which
         we allow, before we start to modify the derivatives to prevent
-        this. This is to prevent a failure mode where the activations
-        become so small that the nonlinearity effectively becomes linear,
-        which makes the module useless and it gets even smaller
-        to try to "turn it off" completely.
+        this.
       max_abs: the maximum average-absolute-value per channel, which
         we allow, before we start to modify the derivatives to prevent
-        this. This is to prevent the possibility of activations getting
-        out of floating point numerical range (especially in half precision).
+        this.
     """
     def __init__(self, channel_dim: int,
                  min_positive: float = 0.05,
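For readers unfamiliar with the class, a minimal usage sketch of the renamed module, mirroring arguments that appear elsewhere in this diff (the deriv_balancer1 call in the ConvolutionModule hunk below). It assumes the recipe's subsampling module, which conformer.py imports from below, is on the path; the tensor shape is illustrative:

import torch
from subsampling import ActivationBalancer  # the renamed class from this recipe

# Channel dim 1 of a (batch, channels, time) tensor: each channel should be positive
# at least 5% of the time, and its mean absolute value should stay below 10.0.
balancer = ActivationBalancer(channel_dim=1, max_abs=10.0,
                              min_positive=0.05, max_positive=1.0)

x = torch.randn(8, 256, 100, requires_grad=True)
y = balancer(x)      # per the docstring, only the backpropped derivatives are modified
y.sum().backward()   # x.grad is where any corrective scaling would show up; here the
                     # constraints are typically already satisfied, so grads stay ~unchanged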
@@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module):
                  max_factor: float = 0.01,
                  min_abs: float = 0.2,
                  max_abs: float = 100.0):
-        super(DerivBalancer, self).__init__()
+        super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
@@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module):
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        return DerivBalancerFunction.apply(x, self.channel_dim,
+        return ActivationBalancerFunction.apply(x, self.channel_dim,
                                            self.min_positive, self.max_positive,
                                            self.max_factor, self.min_abs,
                                            self.max_abs)


 def _double_swish(x: Tensor) -> Tensor:
@@ -524,8 +515,8 @@ def _test_deriv_balancer_sign():
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
+    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
                       max_factor=0.2, min_abs=0.0)

     y_grad = torch.sign(torch.randn(probs.numel(), N))

@@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude():
     x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0,
+    m = ActivationBalancer(channel_dim=0,
                       min_positive=0.0, max_positive=1.0,
                       max_factor=0.2,
                       min_abs=0.2, max_abs=0.8)

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

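The two tests above exercise the sign and magnitude constraints respectively. A toy variant of the sign test, spelling out what the docstring promises when a channel is never positive (again assuming the recipe's subsampling module is importable; the expected factors are read off the docstring, not measured from this commit):

import torch
from subsampling import ActivationBalancer

m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
                       max_factor=0.2, min_abs=0.0)
x = -torch.rand(2, 1000)   # two channels whose values are never positive
x.requires_grad_(True)
y_grad = torch.sign(torch.randn(2, 1000))
m(x).backward(gradient=y_grad)

# Per the docstring, with no positive inputs the negative derivative values should be
# scaled by up to (1 + max_factor) = 1.2 and positive ones by up to (1 - max_factor) = 0.8.
print((x.grad / y_grad).unique())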
(second changed file: the Conformer encoder module, which imports the renamed classes from subsampling)

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d

 import torch
 from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module):
         self.norm_final = BasicNorm(d_model)

         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.balancer = ActivationBalancer(channel_dim=-1,
+                                           min_positive=0.45,
+                                           max_positive=0.55)

         self.dropout = nn.Dropout(dropout)

@@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.depthwise_conv = ScaledConv1d(
             channels,
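The comment in this hunk motivates max_abs=10.0: the balancer watches the per-channel average absolute value and starts correcting derivatives once it drifts above that. A toy computation of the quantity being constrained (shapes and scale are illustrative):

import torch

x = torch.randn(8, 256, 100) * 25.0  # (batch, channels, time), deliberately too large
mean_abs = x.abs().mean(dim=(0, 2))  # average absolute value per channel (channel_dim=1)
print(float((mean_abs > 10.0).float().mean()))  # fraction of channels max_abs=10.0 would flag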
@@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )

-        self.deriv_balancer2 = DerivBalancer(channel_dim=1,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.activation = DoubleSwish()
