diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 600156bf1..a66421adf 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: class ExpScaleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x, scale) + ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed return _exp_scale_swish(x, scale, speed)