diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index e38a94d09..aa842a31f 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,11 +47,15 @@ class Conv2dSubsampling(nn.Module): nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), ) @@ -248,6 +252,68 @@ class ExpScaleSwish(torch.nn.Module): # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() # return x * (self.scale * self.speed).exp() + + + +class DerivBalancerFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, channel_dim: int, + threshold: 0.05, max_factor: 0.05, + epsilon: 1.0e-10) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True) + factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + + ctx.save_for_backward(factor) + ctx.epsilon = epsilon + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + factor, = ctx.saved_tensors + neg_delta_grad = x_grad.abs() * factor + if ctx.epsilon != 0.0: + sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True) + deriv_is_zero = (sum_abs_grad == 0.0) + neg_delta_grad += ctx.epsilon * deriv_is_zero + + return x_grad - neg_delta_grad, None, None, None, None + + + +class DerivBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 0 at the threshold to those extremal values when none + of the inputs are positive. + + When all grads are zero for a channel, this + module sets all the input derivatives for that channel to -epsilon; the + idea is to bring completely dead neurons back to life this way. + """ + def __init__(self, channel_dim: int, + threshold: float = 0.05, + max_factor: float = 0.05, + epsilon: float = 1.0e-10): + super(DerivBalancer, self).__init__() + self.channel_dim = channel_dim + self.threshold = threshold + self.max_factor = max_factor + self.epsilon = epsilon + + def forward(self, x: Tensor) -> Tensor: + return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + self.max_factor, self.epsilon) + + + def _test_exp_scale_swish(): class Swish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -271,5 +337,26 @@ def _test_exp_scale_swish(): +def _test_deriv_balancer(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 500 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + y_grad[-1,:] = 0 + + y = m(x) + y.backward(gradient=y_grad) + print("x = ", x) + print("y grad = ", y_grad) + print("x grad = ", x.grad) + + + if __name__ == '__main__': + _test_deriv_balancer() _test_exp_scale_swish() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 368165008..056958ff6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish +from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer import torch from torch import Tensor, nn @@ -156,6 +156,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -163,6 +165,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 973733d4b..6d6b3f240 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved