use nonzero threshold in DerivBalancer

Daniel Povey 2022-03-10 23:24:55 +08:00
parent 425e274c82
commit 2fa9c636a4
3 changed files with 37 additions and 18 deletions
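The core of the change: DerivBalancer now counts an activation as "positive" only when it exceeds a small margin zero (default 0.02) rather than exactly 0. A minimal standalone sketch of that per-channel statistic, with illustrative shapes (not the commit's code):

    import torch

    def proportion_positive(x: torch.Tensor, channel_dim: int = -1,
                            zero: float = 0.02) -> torch.Tensor:
        # Per-channel fraction of activations above the margin zero,
        # mirroring the (x > zero) comparison this commit introduces.
        if channel_dim < 0:
            channel_dim += x.ndim
        sum_dims = [d for d in range(x.ndim) if d != channel_dim]
        return torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)

    x = 0.01 * torch.randn(20, 30, 40)               # activations that have shrunk close to zero
    print(proportion_positive(x, zero=0.0).mean())   # ~0.5: looks healthy against a zero threshold
    print(proportion_positive(x, zero=0.02).mean())  # far smaller: the balancer now kicks in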


@@ -219,7 +219,7 @@ def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
x = x * (scale * speed).exp()
return x
class ExpScaleSwishFunction(torch.autograd.Function):
class SwishExpScaleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
ctx.save_for_backward(x.detach(), scale.detach())
@@ -237,16 +237,16 @@ class ExpScaleSwishFunction(torch.autograd.Function):
return x.grad, scale.grad, None
class ExpScaleSwish(torch.nn.Module):
# combines ExpScale an Swish
# caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0)
class SwishExpScale(torch.nn.Module):
# combines ExpScale and a Swish (actually the ExpScale is after the Swish).
# caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
def __init__(self, *shape, speed: float = 1.0):
super(ExpScaleSwish, self).__init__()
super(SwishExpScale, self).__init__()
self.scale = nn.Parameter(torch.zeros(*shape))
self.speed = speed
def forward(self, x: Tensor) -> Tensor:
return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
return SwishExpScaleFunction.apply(x, self.scale, self.speed)
# x = (x * torch.sigmoid(x))
# x = (x * torch.sigmoid(x))
# x = x * (self.scale * self.speed).exp()
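The comment above is the key semantic point of the rename: the exponential scale is applied after the Swish. A plain restatement of that ordering, using a standard single Swish (x * sigmoid(x)) purely for illustration; the class's actual nonlinearity is whatever _exp_scale_swish above defines, wrapped in the custom autograd.Function rather than written out like this:

    import torch

    def swish_then_exp_scale(x: torch.Tensor, scale: torch.Tensor,
                             speed: float = 1.0) -> torch.Tensor:
        x = x * torch.sigmoid(x)            # Swish first ...
        return x * (scale * speed).exp()    # ... then the learned exponential scale

    x = torch.randn(50, 60)
    scale = torch.zeros(50, 1)              # scale starts at zero, so exp(...) == 1 at init
    y = swish_then_exp_scale(x, scale, speed=4.0)
    print(torch.allclose(y, x * torch.sigmoid(x)))   # True at initialization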
@@ -313,13 +313,15 @@ class ExpScaleRelu(torch.nn.Module):
class DerivBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, channel_dim: int,
threshold: 0.05, max_factor: 0.05,
epsilon: 1.0e-10) -> Tensor:
threshold: float = 0.05,
max_factor: float = 0.05,
zero: float = 0.02,
epsilon: float = 1.0e-10) -> Tensor:
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)
proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
ctx.save_for_backward(factor)
@@ -328,7 +330,7 @@ class DerivBalancerFunction(torch.autograd.Function):
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
factor, = ctx.saved_tensors
neg_delta_grad = x_grad.abs() * factor
if ctx.epsilon != 0.0:
@@ -336,7 +338,7 @@ class DerivBalancerFunction(torch.autograd.Function):
deriv_is_zero = (sum_abs_grad == 0.0)
neg_delta_grad += ctx.epsilon * deriv_is_zero
return x_grad - neg_delta_grad, None, None, None, None
return x_grad - neg_delta_grad, None, None, None, None, None
class BasicNorm(torch.nn.Module):
@@ -429,20 +431,37 @@ class DerivBalancer(torch.nn.Module):
When all grads are zero for a channel, this
module sets all the input derivatives for that channel to -epsilon; the
idea is to bring completely dead neurons back to life this way.
Args:
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
threshold: the threshold, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_factor: the maximum factor by which we modify the derivatives,
e.g. with max_factor=0.02, the derivatives would be multiplied by
values in the range [0.98..1.02].
zero: we use this value in the comparison (x > 0), i.e. we actually use
(x > zero). The reason for using a threshold slightly greater
than zero is that it will tend to prevent situations where the
inputs shrink close to zero and the nonlinearity (e.g. swish)
behaves like a linear function and we learn nothing.
"""
def __init__(self, channel_dim: int,
threshold: float = 0.05,
max_factor: float = 0.05,
max_factor: float = 0.02,
zero: float = 0.02,
epsilon: float = 1.0e-10):
super(DerivBalancer, self).__init__()
self.channel_dim = channel_dim
self.threshold = threshold
self.max_factor = max_factor
self.zero = zero
self.epsilon = epsilon
def forward(self, x: Tensor) -> Tensor:
return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
self.max_factor, self.epsilon)
self.max_factor, self.zero,
self.epsilon)
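Putting forward and backward together: a self-contained numeric sketch of the gradient rule implemented above, using the same formulas as the diff (the shapes and the all-ones upstream gradient are illustrative, and the epsilon term for completely dead channels is omitted):

    import torch

    threshold, max_factor, zero = 0.05, 0.02, 0.02

    x = 0.01 * torch.randn(1000, 4)       # channel_dim = -1; activations close to zero
    grad_in = torch.ones_like(x)          # stand-in for the upstream gradient

    # forward: compute the per-channel statistic and factor (x itself passes through unchanged)
    proportion_positive = torch.mean((x > zero).to(x.dtype), dim=0, keepdim=True)
    factor = (threshold - proportion_positive).relu() * (max_factor / threshold)

    # backward: subtract factor * |grad|, nudging the gradient of starved channels
    # in the negative direction by at most max_factor * |grad|
    grad_out = grad_in - grad_in.abs() * factor

    print(proportion_positive)     # mostly below threshold for this input
    print(grad_out.mean(dim=0))    # correspondingly a little below 1.0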
@@ -455,7 +474,7 @@ def _test_exp_scale_swish():
x1 = torch.randn(50, 60).detach()
x2 = x1.detach()
m1 = ExpScaleSwish(50, 1, speed=4.0)
m1 = SwishExpScale(50, 1, speed=4.0)
m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
x1.requires_grad = True
x2.requires_grad = True


@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm
import torch
from torch import Tensor, nn
@@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.025),
ExpScaleSwish(dim_feedforward, speed=20.0),
SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
@@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.025),
ExpScaleSwish(dim_feedforward, speed=20.0),
SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
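For context, the feed-forward stack these two hunks configure, written out as one block; the d_model, dim_feedforward and dropout values are illustrative assumptions, and DerivBalancer / SwishExpScale come from the subsampling module per the updated import above:

    import torch.nn as nn
    from subsampling import DerivBalancer, SwishExpScale

    d_model, dim_feedforward, dropout = 256, 2048, 0.1   # illustrative sizes, not from the diff

    feed_forward = nn.Sequential(
        nn.Linear(d_model, dim_feedforward),
        DerivBalancer(channel_dim=-1, threshold=0.05,
                      max_factor=0.025),                 # zero defaults to 0.02 after this commit
        SwishExpScale(dim_feedforward, speed=20.0),
        nn.Dropout(dropout),
        nn.Linear(dim_feedforward, d_model),
    )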


@@ -110,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2",
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved