diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 390d31115..d1ff7f233 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -312,33 +312,36 @@ class ExpScaleRelu(torch.nn.Module):
 
 class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor,
+                channel_dim: int,
                 threshold: float = 0.05,
                 max_factor: float = 0.05,
-                zero: float = 0.02,
-                epsilon: float = 1.0e-10) -> Tensor:
+                min_abs: float = 0.2) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
+            xgt0 = x > 0
+            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
 
-            ctx.save_for_backward(factor)
-            ctx.epsilon = epsilon
+            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
+
+            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
-        factor, = ctx.saved_tensors
-        neg_delta_grad = x_grad.abs() * factor
-        if ctx.epsilon != 0.0:
-            sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True)
-            deriv_is_zero = (sum_abs_grad == 0.0)
-            neg_delta_grad += ctx.epsilon * deriv_is_zero
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+        factor, xgt0, below_threshold = ctx.saved_tensors
+        dtype = x_grad.dtype
+        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
 
-        return x_grad - neg_delta_grad, None, None, None, None, None
+        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
+
+
+        return x_grad - neg_delta_grad, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
@@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module):
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.02,
-                 zero: float = 0.02,
-                 epsilon: float = 1.0e-10):
+                 min_abs: float = 0.2):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
-        self.zero = zero
-        self.epsilon = epsilon
+        self.min_abs = min_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.zero,
-                                           self.epsilon)
+                                           self.max_factor, self.min_abs)
+
 
 
@@ -505,23 +506,41 @@ def _test_exp_scale_relu():
 
 
-def _test_deriv_balancer():
+def _test_deriv_balancer_sign():
     channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
     N = 500
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10)
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
     y_grad[-1,:] = 0
 
     y = m(x)
     y.backward(gradient=y_grad)
-    print("x = ", x)
-    print("y grad = ", y_grad)
-    print("x grad = ", x.grad)
+    print("_test_deriv_balancer_sign: x = ", x)
+    print("_test_deriv_balancer_sign: y grad = ", y_grad)
+    print("_test_deriv_balancer_sign: x grad = ", x.grad)
+
+def _test_deriv_balancer_magnitude():
+    channel_dim = 0
+    magnitudes = torch.arange(0, 1, 0.01)
+    N = 500
+    x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1))
+    x = x.detach()
+    x.requires_grad = True
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
+
+    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+    y_grad[-1,:] = 0
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_deriv_balancer_magnitude: x = ", x)
+    print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
+    print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
 
 
 def _test_basic_norm():
@@ -543,7 +562,8 @@ def _test_basic_norm():
 
 
 if __name__ == '__main__':
-    _test_deriv_balancer()
+    _test_deriv_balancer_sign()
+    _test_deriv_balancer_magnitude()
     _test_exp_scale_swish()
     _test_exp_scale_relu()
     _test_basic_norm()
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 36a1ae869..618d90490 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
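
Note for reviewers: the behavioral change is easiest to see outside the autograd plumbing. The sketch below is my own distillation, not code from the repo; the helper name min_abs_correction_sketch and the toy tensor shapes are invented for illustration. It reproduces only the new min_abs correction from DerivBalancerFunction.backward: on channels whose mean absolute activation is below min_abs, the gradient is shifted so that a descent step pushes activations away from zero, positive values up and negative values down, by at most max_factor times the gradient magnitude.

# Standalone sketch of the new min_abs penalty (assumption: distilled
# from DerivBalancerFunction.backward in the patch above; the function
# name and toy data are hypothetical, for illustration only).
import torch

def min_abs_correction_sketch(x: torch.Tensor,
                              x_grad: torch.Tensor,
                              channel_dim: int = 0,
                              max_factor: float = 0.2,
                              min_abs: float = 0.2) -> torch.Tensor:
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    xgt0 = x > 0
    # Flag channels whose mean |activation| has collapsed below min_abs.
    below_threshold = torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs
    dtype = x_grad.dtype
    # (xgt0 - 0.5) is +0.5 where x > 0 and -0.5 elsewhere, so after the
    # (max_factor * 2.0) scaling the correction is +/- max_factor, and it
    # applies only on flagged channels.
    too_small_factor = (below_threshold.to(dtype)
                        * (xgt0.to(dtype) - 0.5) * (max_factor * 2.0))
    # Subtracting |grad| * too_small_factor lowers the gradient where
    # x > 0 and raises it where x <= 0, so a descent step (-lr * grad)
    # moves small-magnitude channels away from zero.
    return x_grad - x_grad.abs() * too_small_factor

x = 0.01 * torch.randn(4, 500)   # every channel's mean |x| is << min_abs
g = torch.randn(4, 500)
g_adj = min_abs_correction_sketch(x, g)
# The per-element shift is bounded by max_factor * |g|.
assert (g - g_adj).abs().max() <= 0.2 * g.abs().max() + 1e-6

This replaces the old zero/epsilon mechanism: rather than special-casing channels whose incoming gradients vanish, the balancer now directly penalizes channels whose activations collapse toward zero, which is what the new _test_deriv_balancer_magnitude test exercises.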