use nonzero threshold in DerivBalancer

Daniel Povey 2022-03-10 23:24:55 +08:00
parent 425e274c82
commit 2fa9c636a4
3 changed files with 37 additions and 18 deletions
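For context, the substantive change here is that DerivBalancer's "proportion positive" statistic now compares activations against a small positive constant `zero` (default 0.02) rather than against exactly 0, so channels whose outputs hover just above zero still count as under-active. A minimal sketch of the statistic, assuming a hypothetical (frames, channels) input with channels on the last dimension:

import torch

x = torch.randn(100, 8)  # hypothetical (frames, channels) activations
zero = 0.02              # the new cutoff; previously the comparison used 0.0

# old statistic: fraction of frames where each channel is strictly positive
prop_old = torch.mean((x > 0).to(x.dtype), dim=0, keepdim=True)

# new statistic: a frame only counts if the activation clears the cutoff
prop_new = torch.mean((x > zero).to(x.dtype), dim=0, keepdim=True)

Channels whose proportion falls below `threshold` then have their derivatives modified, as in the hunks below.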

View File

@@ -219,7 +219,7 @@ def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     x = x * (scale * speed).exp()
     return x
 
-class ExpScaleSwishFunction(torch.autograd.Function):
+class SwishExpScaleFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
         ctx.save_for_backward(x.detach(), scale.detach())
@@ -237,16 +237,16 @@ class ExpScaleSwishFunction(torch.autograd.Function):
         return x.grad, scale.grad, None
 
-class ExpScaleSwish(torch.nn.Module):
-    # combines ExpScale an Swish
-    # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0)
+class SwishExpScale(torch.nn.Module):
+    # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
+    # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
     def __init__(self, *shape, speed: float = 1.0):
-        super(ExpScaleSwish, self).__init__()
+        super(SwishExpScale, self).__init__()
         self.scale = nn.Parameter(torch.zeros(*shape))
         self.speed = speed
 
     def forward(self, x: Tensor) -> Tensor:
-        return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
+        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
         # x = (x * torch.sigmoid(x))
         # x = (x * torch.sigmoid(x))
         # x = x * (self.scale * self.speed).exp()
@@ -313,13 +313,15 @@ class ExpScaleRelu(torch.nn.Module):
 class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, channel_dim: int,
-                threshold: 0.05, max_factor: 0.05,
-                epsilon: 1.0e-10) -> Tensor:
+                threshold: float = 0.05,
+                max_factor: float = 0.05,
+                zero: float = 0.02,
+                epsilon: float = 1.0e-10) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)
+            proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
             ctx.save_for_backward(factor)
@@ -328,7 +330,7 @@ class DerivBalancerFunction(torch.autograd.Function):
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         factor, = ctx.saved_tensors
         neg_delta_grad = x_grad.abs() * factor
         if ctx.epsilon != 0.0:
@@ -336,7 +338,7 @@ class DerivBalancerFunction(torch.autograd.Function):
             deriv_is_zero = (sum_abs_grad == 0.0)
             neg_delta_grad += ctx.epsilon * deriv_is_zero
-        return x_grad - neg_delta_grad, None, None, None, None
+        return x_grad - neg_delta_grad, None, None, None, None, None
 
 class BasicNorm(torch.nn.Module):
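The backward pass above leaves well-behaved channels untouched and perturbs the gradients of under-active ones; the extra trailing `None` simply matches the new `zero` argument of forward. A rough worked example, with illustrative values only, assuming threshold=0.05 and max_factor=0.05 for a channel that is positive 1% of the time:

import torch

threshold, max_factor = 0.05, 0.05
proportion_positive = torch.tensor(0.01)  # channel clears the cutoff 1% of the time

# factor = (0.05 - 0.01) * (0.05 / 0.05) = 0.04
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)

x_grad = torch.tensor([0.5, -0.5])
neg_delta_grad = x_grad.abs() * factor
new_grad = x_grad - neg_delta_grad  # tensor([0.4800, -0.5200])

# Positive gradients shrink and negative ones grow in magnitude, which biases
# the upstream parameters toward producing more positive activations.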
@@ -429,20 +431,37 @@ class DerivBalancer(torch.nn.Module):
     When all grads are zero for a channel, this
     module sets all the input derivatives for that channel to -epsilon; the
     idea is to bring completely dead neurons back to life this way.
+
+    Args:
+      channel_dim: the dimension/axis corresponding to the channel, e.g.
+          -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+      threshold: the threshold, per channel, of the proportion of the time
+          that (x > 0), below which we start to modify the derivatives.
+      max_factor: the maximum factor by which we modify the derivatives,
+          e.g. with max_factor=0.02, the derivatives would be multiplied by
+          values in the range [0.98..1.02].
+      zero: we use this value in the comparison (x > 0), i.e. we actually use
+          (x > zero). The reason for using a threshold slightly greater
+          than zero is that it will tend to prevent situations where the
+          inputs shrink close to zero and the nonlinearity (e.g. swish)
+          behaves like a linear function, so we learn nothing.
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
-                 max_factor: float = 0.05,
+                 max_factor: float = 0.02,
+                 zero: float = 0.02,
                  epsilon: float = 1.0e-10):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
+        self.zero = zero
         self.epsilon = epsilon
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.epsilon)
+                                           self.max_factor, self.zero,
+                                           self.epsilon)
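Putting it together, a hedged usage sketch of the updated module (the import assumes this recipe's subsampling.py is on the Python path; the argument values are only examples):

import torch
from subsampling import DerivBalancer  # assumption: this recipe's subsampling.py is importable

balancer = DerivBalancer(channel_dim=-1, threshold=0.05,
                         max_factor=0.025, zero=0.02)
x = torch.randn(10, 256, requires_grad=True)
y = balancer(x)      # identity in the forward direction
y.sum().backward()   # gradients of persistently non-positive channels are rescaled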
@@ -455,7 +474,7 @@ def _test_exp_scale_swish():
     x1 = torch.randn(50, 60).detach()
     x2 = x1.detach()
 
-    m1 = ExpScaleSwish(50, 1, speed=4.0)
+    m1 = SwishExpScale(50, 1, speed=4.0)
     m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
     x1.requires_grad = True
     x2.requires_grad = True

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
+from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm
 
 import torch
 from torch import Tensor, nn
@@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved