Cosmetic changes/renaming things

Daniel Povey 2022-03-16 19:27:45 +08:00
parent dfc75752c4
commit e838c192ef
2 changed files with 37 additions and 42 deletions

View File

@@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module):
             ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module):
         # needed.
         self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -177,7 +179,7 @@ class VggSubsampling(nn.Module):
-class DerivBalancerFunction(torch.autograd.Function):
+class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
@@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d):
-class DerivBalancer(torch.nn.Module):
+class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
     each channel, that it is positive at least a proportion `threshold` of the
     time.  It does this by multiplying negative derivative values by up to
     (1+max_factor), and positive derivative values by up to (1-max_factor),
-    interpolated from 0 at the threshold to those extremal values when none
+    interpolated from 1 at the threshold to those extremal values when none
     of the inputs are positive.
-    When all grads are zero for a channel, this
-    module sets all the input derivatives for that channel to -epsilon; the
-    idea is to bring completely dead neurons back to life this way.

     Args:
-          channel_dim: the dimension/axi corresponding to the channel, e.g.
+          channel_dim: the dimension/axis corresponding to the channel, e.g.
              -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
              that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
              that (x > 0), below which we start to modify the derivatives.
-          max_factor: the maximum factor by which we modify the derivatives,
+          max_factor: the maximum factor by which we modify the derivatives for
+             either the sign constraint or the magnitude constraint;
             e.g. with max_factor=0.02, the the derivatives would be multiplied by
-            values in the range [0.98..1.01].
+            values in the range [0.98..1.02].
-          zero: we use this value in the comparison (x > 0), i.e. we actually use
-             (x > zero).  The reason for using a threshold slightly greater
-             than zero is that it will tend to prevent situations where the
-             inputs shrink close to zero and the nonlinearity (e.g. swish)
-             behaves like a linear function and we learn nothing.
           min_abs: the minimum average-absolute-value per channel, which
              we allow, before we start to modify the derivatives to prevent
-             this.  This is to prevent a failure mode where the activations
-             become so small that the nonlinearity effectively becomes linear,
-             which makes the module useless and it gets even smaller
-             to try to "turn it off" completely.
+             this.
           max_abs: the maximum average-absolute-value per channel, which
              we allow, before we start to modify the derivatives to prevent
-             this.  This is to prevent the possibility of activations getting
-             out of floating point numerical range (especially in half precision).
+             this.
     """
     def __init__(self, channel_dim: int,
                  min_positive: float = 0.05,
@@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module):
                  max_factor: float = 0.01,
                  min_abs: float = 0.2,
                  max_abs: float = 100.0):
-        super(DerivBalancer, self).__init__()
+        super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
@@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module):
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        return DerivBalancerFunction.apply(x, self.channel_dim,
-                                           self.min_positive, self.max_positive,
-                                           self.max_factor, self.min_abs,
-                                           self.max_abs)
+        return ActivationBalancerFunction.apply(x, self.channel_dim,
+                                                self.min_positive, self.max_positive,
+                                                self.max_factor, self.min_abs,
+                                                self.max_abs)


 def _double_swish(x: Tensor) -> Tensor:
@@ -524,8 +515,8 @@ def _test_deriv_balancer_sign():
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                      max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
+                           max_factor=0.2, min_abs=0.0)

     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude():
     x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0,
-                      min_positive=0.0, max_positive=1.0,
-                      max_factor=0.2,
-                      min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(channel_dim=0,
+                           min_positive=0.0, max_positive=1.0,
+                           max_factor=0.2,
+                           min_abs=0.2, max_abs=0.8)

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
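
As a reading aid for the docstring changed above, the sketch below illustrates the sign constraint it describes: scaling negative derivatives by up to (1+max_factor) and positive ones by up to (1-max_factor) for channels that are positive too rarely. This is not the code from this commit; the real ActivationBalancerFunction also enforces the min_abs/max_abs magnitude constraint and interpolates the scale smoothly from 1 at the threshold, and the function name here is purely illustrative.

import torch
from torch import Tensor

def toy_sign_balancer_grad(x: Tensor, x_grad: Tensor,
                           channel_dim: int = -1,
                           min_positive: float = 0.05,
                           max_factor: float = 0.01) -> Tensor:
    # Proportion of the time each channel of x is positive, averaged over
    # every dimension except channel_dim (negative dims are offsets from x.ndim).
    other_dims = [d for d in range(x.ndim) if d != channel_dim % x.ndim]
    proportion_positive = (x > 0).to(x.dtype).mean(dim=other_dims, keepdim=True)

    # Channels positive less often than min_positive get the full max_factor
    # in this toy version (the real code interpolates it).
    factor = max_factor * (proportion_positive < min_positive).to(x.dtype)

    # Negative derivatives are scaled by (1 + factor) and positive ones by
    # (1 - factor), biasing subsequent updates toward making x positive.
    return x_grad * torch.where(x_grad > 0, 1.0 - factor, 1.0 + factor)

With channel_dim=1 this corresponds to how the balancers above are attached after ScaledConv2d layers, where dimension 1 is the channel dimension.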

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence

-from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
 import torch
 from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module):
         self.norm_final = BasicNorm(d_model)

         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.balancer = ActivationBalancer(channel_dim=-1,
+                                           min_positive=0.45,
+                                           max_positive=0.55)

         self.dropout = nn.Dropout(dropout)
@@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )
-        self.deriv_balancer2 = DerivBalancer(channel_dim=1,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)

         self.activation = DoubleSwish()
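
After this commit, downstream code constructs the renamed class exactly as the hunks above show. As a usage illustration in the style of the feed_forward module of ConformerEncoderLayer (the dimensions and dropout value here are placeholders, not values from the diff):

import torch
import torch.nn as nn
from subsampling import ActivationBalancer, DoubleSwish, ScaledLinear

# The balancer sits between the expansion ScaledLinear and the DoubleSwish
# activation; it only modifies backpropped derivatives, so the forward
# output of the block is unaffected by its presence.
feed_forward = nn.Sequential(
    ScaledLinear(256, 1024),
    ActivationBalancer(channel_dim=-1),
    DoubleSwish(),
    nn.Dropout(0.1),
    ScaledLinear(1024, 256, initial_scale=0.25),
)

x = torch.randn(10, 4, 256)  # (time, batch, d_model); the last dim is the channel here
y = feed_forward(x)          # same shape as x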