diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 29621bf52..390d31115 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -219,7 +219,7 @@ def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     x = x * (scale * speed).exp()
     return x
 
-class ExpScaleSwishFunction(torch.autograd.Function):
+class SwishExpScaleFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
         ctx.save_for_backward(x.detach(), scale.detach())
@@ -237,16 +237,16 @@ class ExpScaleSwishFunction(torch.autograd.Function):
         return x.grad, scale.grad, None
 
 
-class ExpScaleSwish(torch.nn.Module):
-    # combines ExpScale an Swish
-    # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0)
+class SwishExpScale(torch.nn.Module):
+    # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
+    # caution: need to specify speed by name, e.g. SwishExpScale(50, speed=4.0)
     def __init__(self, *shape, speed: float = 1.0):
-        super(ExpScaleSwish, self).__init__()
+        super(SwishExpScale, self).__init__()
         self.scale = nn.Parameter(torch.zeros(*shape))
         self.speed = speed
 
     def forward(self, x: Tensor) -> Tensor:
-        return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
+        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
         # x = (x * torch.sigmoid(x))
         # x = (x * torch.sigmoid(x))
         # x = x * (self.scale * self.speed).exp()
@@ -313,13 +313,15 @@ class ExpScaleRelu(torch.nn.Module):
 class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, channel_dim: int,
-                threshold: 0.05, max_factor: 0.05,
-                epsilon: 1.0e-10) -> Tensor:
+                threshold: float = 0.05,
+                max_factor: float = 0.05,
+                zero: float = 0.02,
+                epsilon: float = 1.0e-10) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)
+            proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
 
             ctx.save_for_backward(factor)
@@ -328,7 +330,7 @@ class DerivBalancerFunction(torch.autograd.Function):
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         factor, = ctx.saved_tensors
         neg_delta_grad = x_grad.abs() * factor
         if ctx.epsilon != 0.0:
@@ -336,7 +338,7 @@ class DerivBalancerFunction(torch.autograd.Function):
             deriv_is_zero = (sum_abs_grad == 0.0)
             neg_delta_grad += ctx.epsilon * deriv_is_zero
 
-        return x_grad - neg_delta_grad, None, None, None, None
+        return x_grad - neg_delta_grad, None, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
@@ -429,20 +431,37 @@ class DerivBalancer(torch.nn.Module):
 
     When all grads are zero for a channel, this module sets all the input
     derivatives for that channel to -epsilon; the idea is to bring completely
     dead neurons back to life this way.
+
+    Args:
+      channel_dim: the dimension/axis corresponding to the channel, e.g.
+        -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+      threshold: the threshold, per channel, of the proportion of the time
+        that (x > 0), below which we start to modify the derivatives.
+      max_factor: the maximum factor by which we modify the derivatives,
+        e.g. with max_factor=0.02, the derivatives would be multiplied by
+        values in the range [0.98..1.02].
+      zero: the value we use in place of 0 in the comparison (x > 0), i.e.
+        we actually use (x > zero). The reason for using a threshold slightly
+        greater than zero is that it will tend to prevent situations where the
+        inputs shrink close to zero and the nonlinearity (e.g. swish)
+        behaves like a linear function and we learn nothing.
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
-                 max_factor: float = 0.05,
+                 max_factor: float = 0.02,
+                 zero: float = 0.02,
                  epsilon: float = 1.0e-10):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
+        self.zero = zero
         self.epsilon = epsilon
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.epsilon)
+                                           self.max_factor, self.zero,
+                                           self.epsilon)
 
 
@@ -455,7 +474,7 @@ def _test_exp_scale_swish():
     x1 = torch.randn(50, 60).detach()
     x2 = x1.detach()
 
-    m1 = ExpScaleSwish(50, 1, speed=4.0)
+    m1 = SwishExpScale(50, 1, speed=4.0)
     m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
     x1.requires_grad = True
     x2.requires_grad = True
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 4cf66e2fe..7a7a09c27 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm
+from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm
 import torch
 from torch import Tensor, nn
 
@@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleSwish(dim_feedforward, speed=20.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c355c7ad3..36a1ae869 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
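
Note: as a minimal standalone sketch (not part of the patch; the tensor shapes and the variable names here are illustrative only), this is the per-channel statistic and gradient adjustment that the updated DerivBalancerFunction computes, including the new `zero` threshold:

import torch

# Per-channel proportion of activations exceeding `zero` (rather than 0),
# matching the updated forward() in the diff above.
x = torch.randn(4, 8)
channel_dim, threshold, max_factor, zero = -1, 0.05, 0.02, 0.02
if channel_dim < 0:
    channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)

# Channels that are active less than `threshold` of the time get a positive
# factor, up to max_factor for a channel that is never active.
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)

# backward() then subtracts |grad| * factor, scaling positive grads by
# (1 - factor) and negative grads by (1 + factor), i.e. the [0.98..1.02]
# range mentioned in the docstring; lowering dL/dx this way nudges
# starved channels toward positive activations.
x_grad = torch.randn_like(x)  # stand-in for the upstream gradient
adjusted_grad = x_grad - x_grad.abs() * factor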