Add deriv-balancing code

Daniel Povey 2022-03-04 20:19:11 +08:00
parent eb3ed54202
commit 6252282fd0
3 changed files with 93 additions and 2 deletions

View File

@@ -47,11 +47,15 @@ class Conv2dSubsampling(nn.Module):
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            DerivBalancer(channel_dim=1, threshold=0.02,
                          max_factor=0.02),
            nn.ReLU(),
            ExpScale(odim, 1, 1, speed=20.0),
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            DerivBalancer(channel_dim=1, threshold=0.02,
                          max_factor=0.02),
            nn.ReLU(),
            ExpScale(odim, 1, 1, speed=20.0),
        )
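
Worth noting when reading this hunk: DerivBalancer is an identity in the forward direction and only rewrites gradients on the way back, so placing it just before each ReLU leaves the forward computation (and inference) untouched. A minimal sketch of that property, assuming the DerivBalancer class added in this commit is importable; the shape is illustrative only:

import torch
from subsampling import DerivBalancer

x = torch.randn(8, 64, 19, 19, requires_grad=True)  # (N, C, H, W), made-up shape
y = DerivBalancer(channel_dim=1, threshold=0.02, max_factor=0.02)(x)
assert torch.equal(y, x)  # forward pass is a no-op; only backward() differs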
@@ -248,6 +252,68 @@ class ExpScaleSwish(torch.nn.Module):
        # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
        # return x * (self.scale * self.speed).exp()


class DerivBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int,
                threshold: float = 0.05, max_factor: float = 0.05,
                epsilon: float = 1.0e-10) -> Tensor:
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            # Per-channel proportion of positive inputs; keepdim=True so that
            # `factor` broadcasts against x_grad in backward().
            proportion_positive = torch.mean((x > 0).to(x.dtype),
                                             dim=sum_dims, keepdim=True)
            # Ramps linearly from 0 (channel positive at least `threshold`
            # of the time) up to max_factor (channel never positive).
            factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
            ctx.save_for_backward(factor)
            ctx.epsilon = epsilon
            ctx.sum_dims = sum_dims
        return x
    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        factor, = ctx.saved_tensors
        # Subtracting |x_grad| * factor scales positive grads by (1 - factor)
        # and negative grads by (1 + factor), biasing the input upward.
        neg_delta_grad = x_grad.abs() * factor
        if ctx.epsilon != 0.0:
            # Channels whose incoming gradient is identically zero ("dead"
            # neurons) get a constant -epsilon gradient instead.
            sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims,
                                     keepdim=True)
            deriv_is_zero = (sum_abs_grad == 0.0).to(x_grad.dtype)
            neg_delta_grad += ctx.epsilon * deriv_is_zero
        return x_grad - neg_delta_grad, None, None, None, None
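
The backward rule is easiest to see with concrete numbers. A hand-worked example using the default threshold=0.05 and max_factor=0.05 (plain Python, values chosen only for illustration):

# Suppose a channel's input is positive just 1% of the time:
threshold, max_factor = 0.05, 0.05
proportion_positive = 0.01
factor = max(threshold - proportion_positive, 0.0) * (max_factor / threshold)
# factor is about 0.04; x_grad - |x_grad| * factor then gives:
print(1.0 - abs(1.0) * factor)    # about 0.96: positive grads shrink
print(-1.0 - abs(-1.0) * factor)  # about -1.04: negative grads grow in magnitude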


class DerivBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to encourage, for
    each channel, that the input is positive at least a proportion
    `threshold` of the time.  It does this by multiplying negative derivative
    values by up to (1+max_factor) and positive derivative values by up to
    (1-max_factor), interpolated from 1 at the threshold to those extremal
    values when none of the inputs is positive.

    When all gradients for a channel are zero, this module instead sets the
    input derivatives for that channel to -epsilon; the idea is to bring
    completely dead neurons back to life this way.
    """
    def __init__(self, channel_dim: int,
                 threshold: float = 0.05,
                 max_factor: float = 0.05,
                 epsilon: float = 1.0e-10):
        super(DerivBalancer, self).__init__()
        self.channel_dim = channel_dim
        self.threshold = threshold
        self.max_factor = max_factor
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
                                           self.max_factor, self.epsilon)
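
The epsilon term is easiest to see in isolation: a channel whose incoming gradient is identically zero gets a small constant negative gradient back instead of nothing. A minimal sketch with made-up shapes:

x = torch.zeros(4, 3, requires_grad=True)
y = DerivBalancer(channel_dim=1, epsilon=1.0e-10)(x)
g = torch.ones(4, 3)
g[:, 0] = 0.0                # channel 0 receives no gradient at all
y.backward(gradient=g)
print(x.grad[:, 0])          # about -1e-10 everywhere: the dead-neuron rescue term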


def _test_exp_scale_swish():
    class Swish(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
@@ -271,5 +337,26 @@ def _test_exp_scale_swish():
def _test_deriv_balancer():
    channel_dim = 0
    probs = torch.arange(0, 1, 0.01)
    N = 500
    # Each row (channel) of x is positive with probability probs[i], so the
    # low-probability channels should get their gradients rebalanced.
    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
    x = x.detach()
    x.requires_grad = True
    m = DerivBalancer(channel_dim=channel_dim, threshold=0.05,
                      max_factor=0.2, epsilon=1.0e-10)

    y_grad = torch.sign(torch.randn(probs.numel(), N))
    # Zero all grads for the last channel: its x.grad should come back
    # as -epsilon, exercising the dead-neuron path.
    y_grad[-1, :] = 0

    y = m(x)
    y.backward(gradient=y_grad)
    print("x = ", x)
    print("y grad = ", y_grad)
    print("x grad = ", x.grad)


if __name__ == '__main__':
    _test_deriv_balancer()
    _test_exp_scale_swish()

View File

@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer
import torch
from torch import Tensor, nn
@@ -156,6 +156,8 @@ class ConformerEncoderLayer(nn.Module):
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            DerivBalancer(channel_dim=-1, threshold=0.02,
                          max_factor=0.02),
            ExpScaleSwish(dim_feedforward, speed=20.0),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
@@ -163,6 +165,8 @@ class ConformerEncoderLayer(nn.Module):
        self.feed_forward_macaron = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            DerivBalancer(channel_dim=-1, threshold=0.02,
                          max_factor=0.02),
            ExpScaleSwish(dim_feedforward, speed=20.0),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
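
Note that channel_dim is -1 in these two hunks but 1 in the Conv2dSubsampling hunk above: nn.Linear puts features on the last dimension while nn.Conv2d puts channels on dim 1, and the proportion-positive statistic must be averaged over all the other dims. A quick sketch (shapes are illustrative, not taken from the recipe):

x_conv = torch.randn(8, 256, 50, 19)   # Conv2d output, (N, C, H, W)
x_ffn = torch.randn(50, 8, 2048)       # Linear output, features last
DerivBalancer(channel_dim=1)(x_conv)   # statistic averaged over dims (0, 2, 3)
DerivBalancer(channel_dim=-1)(x_ffn)   # statistic averaged over dims (0, 1)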

View File

@@ -110,7 +110,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved