diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py index 971fc578a..4fb8bef7f 100644 --- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py @@ -668,8 +668,10 @@ class SimpleCausalEncoderLayer(nn.Module): class ReverseGrad(torch.autograd.Function): + @staticmethod def forward(ctx, x): return x + @staticmethod def backward(ctx, x_grad): return -x_grad @@ -1770,6 +1772,7 @@ def _test_bidirectional_conformer(): N = 10 C = num_features feats = torch.randn(N, T, C) + feats.requires_grad = True tokens = _gen_rand_tokens(N) supervision = _gen_supervision(tokens) @@ -1802,6 +1805,9 @@ def _test_bidirectional_conformer(): print("self prediction logprob = ", self_prediction_logprob) + loss = -(decoder_logprob + reverse_decoder_logprob - self_prediction_logprob) + loss.backward() + if __name__ == '__main__': _test_bidirectional_conformer()