diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ea0204138..8d01d8fc0 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,14 +47,12 @@ class Conv2dSubsampling(nn.Module): ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -325,7 +323,8 @@ class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: float, # e.g. 0.05 + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 max_factor: float, # e.g. 0.01 min_abs: float, # e.g. 0.2 max_abs: float, # e.g. 1000.0 @@ -336,7 +335,13 @@ class DerivBalancerFunction(torch.autograd.Function): sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) below_threshold = (mean_abs < min_abs) @@ -348,16 +353,14 @@ class DerivBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) - - - return x_grad - neg_delta_grad, None, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -516,7 +519,9 @@ class DerivBalancer(torch.nn.Module): Args: channel_dim: the dimension/axi corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - threshold: the threshold, per channel, of the proportion of the time + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_factor: the maximum factor by which we modify the derivatives, e.g. with max_factor=0.02, the the derivatives would be multiplied by @@ -538,19 +543,22 @@ class DerivBalancer(torch.nn.Module): out of floating point numerical range (especially in half precision). """ def __init__(self, channel_dim: int, - threshold: float = 0.05, + min_positive: float = 0.05, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 1000.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim - self.threshold = threshold + self.min_positive = min_positive + self.max_positive = max_positive self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + return DerivBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, self.max_factor, self.min_abs, self.max_abs) @@ -600,14 +608,14 @@ def _test_exp_scale_relu(): def _test_deriv_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) - N = 500 + N = 1000 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) @@ -618,14 +626,16 @@ def _test_deriv_balancer_sign(): def _test_deriv_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) - N = 500 - x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 9dd6bae4d..3516c2205 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -158,8 +158,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -167,8 +166,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),