Cosmetic changes/renaming things

Daniel Povey 2022-03-16 19:27:45 +08:00
parent dfc75752c4
commit e838c192ef
2 changed files with 37 additions and 42 deletions

View File

@@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module):
ScaledConv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
DerivBalancer(channel_dim=1),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
DerivBalancer(channel_dim=1),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module):
# needed.
self.out_norm = BasicNorm(odim, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -177,7 +179,7 @@ class VggSubsampling(nn.Module):
class DerivBalancerFunction(torch.autograd.Function):
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor,
channel_dim: int,
@@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d):
class DerivBalancer(torch.nn.Module):
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
each channel, that it is positive at least a proportion `threshold` of the
time. It does this by multiplying negative derivative values by up to
(1+max_factor), and positive derivative values by up to (1-max_factor),
interpolated from 0 at the threshold to those extremal values when none
interpolated from 1 at the threshold to those extremal values when none
of the inputs are positive.
When all grads are zero for a channel, this
module sets all the input derivatives for that channel to -epsilon; the
idea is to bring completely dead neurons back to life this way.
Args:
channel_dim: the dimension/axi corresponding to the channel, e.g.
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
min_positive: the minimum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_positive: the maximum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_factor: the maximum factor by which we modify the derivatives,
max_factor: the maximum factor by which we modify the derivatives for
either the sign constraint or the magnitude constraint;
e.g. with max_factor=0.02, the derivatives would be multiplied by
values in the range [0.98..1.01].
zero: we use this value in the comparison (x > 0), i.e. we actually use
(x > zero). The reason for using a threshold slightly greater
than zero is that it will tend to prevent situations where the
inputs shrink close to zero and the nonlinearity (e.g. swish)
behaves like a linear function and we learn nothing.
values in the range [0.98..1.02].
min_abs: the minimum average-absolute-value per channel, which
we allow, before we start to modify the derivatives to prevent
this. This is to prevent a failure mode where the activations
become so small that the nonlinearity effectively becomes linear,
which makes the module useless and it gets even smaller
to try to "turn it off" completely.
this.
max_abs: the maximum average-absolute-value per channel, which
we allow, before we start to modify the derivatives to prevent
this. This is to prevent the possibility of activations getting
out of floating point numerical range (especially in half precision).
this.
"""
def __init__(self, channel_dim: int,
min_positive: float = 0.05,
@@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module):
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0):
super(DerivBalancer, self).__init__()
super(ActivationBalancer, self).__init__()
self.channel_dim = channel_dim
self.min_positive = min_positive
self.max_positive = max_positive
@@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module):
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
return DerivBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs,
self.max_abs)
return ActivationBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs,
self.max_abs)
def _double_swish(x: Tensor) -> Tensor:
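
Not part of the commit, but a minimal usage sketch of the renamed module, based only on the constructor arguments and forward() shown in this diff. The tensor shape is an illustrative assumption; per the docstring, the forward pass leaves the values unchanged and only the backpropped gradients are rescaled.

import torch
from subsampling import ActivationBalancer  # import path as used in the second changed file

# Same construction as self.out_balancer above: keep the per-channel
# proportion of positive values between 0.45 and 0.55, i.e. roughly zero median.
balancer = ActivationBalancer(channel_dim=-1,
                              min_positive=0.45,
                              max_positive=0.55)

x = torch.randn(8, 100, 256, requires_grad=True)  # (batch, time, channels) -- assumed shape
y = balancer(x)      # values pass through; the gradient shaping happens in backward
y.sum().backward()
assert x.grad.shape == x.shape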
@@ -524,8 +515,8 @@ def _test_deriv_balancer_sign():
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude():
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = x.detach()
x.requires_grad = True
m = DerivBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
m = ActivationBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

View File

@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
import torch
from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module):
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
self.balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
self.dropout = nn.Dropout(dropout)
@@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module):
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
min_positive=0.05, max_positive=1.0)
self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
min_positive=0.05,
max_positive=1.0)
self.depthwise_conv = ScaledConv1d(
channels,
@@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module):
bias=bias,
)
self.deriv_balancer2 = DerivBalancer(channel_dim=1,
min_positive=0.05, max_positive=1.0)
self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
min_positive=0.05,
max_positive=1.0)
self.activation = DoubleSwish()
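
A pattern that recurs across both changed files is a scaled linear or convolutional layer, followed by ActivationBalancer, followed by DoubleSwish. As a consolidated sketch (not part of the commit; d_model, dim_feedforward and dropout are illustrative values), the feed-forward branch built in ConformerEncoderLayer above amounts to:

import torch.nn as nn
from subsampling import ScaledLinear, ActivationBalancer, DoubleSwish

d_model, dim_feedforward, dropout = 512, 2048, 0.1  # assumed hyperparameters

# Expansion layer -> gradient-shaping balancer -> DoubleSwish -> dropout ->
# down-projection with a small initial scale, as in the diff above.
feed_forward = nn.Sequential(
    ScaledLinear(d_model, dim_feedforward),
    ActivationBalancer(channel_dim=-1),
    DoubleSwish(),
    nn.Dropout(dropout),
    ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)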