Get backward working

Daniel Povey 2021-09-18 12:36:50 +08:00
parent 058fff0365
commit a20d490332


@@ -668,8 +668,10 @@ class SimpleCausalEncoderLayer(nn.Module):
class ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x_grad):
        return -x_grad
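ReverseGrad is a gradient-reversal layer: the forward pass is the identity, and the backward pass negates the incoming gradient. A minimal standalone sketch of that behaviour (the class body repeats the one added above so the snippet runs on its own):

import torch

class ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x_grad):
        return -x_grad

x = torch.ones(3, requires_grad=True)
y = ReverseGrad.apply(x)
y.sum().backward()
print(x.grad)  # tensor([-1., -1., -1.]) -- the gradient arrives with its sign flipped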
@@ -1770,6 +1772,7 @@ def _test_bidirectional_conformer():
    N = 10
    C = num_features
    feats = torch.randn(N, T, C)
    feats.requires_grad = True
    tokens = _gen_rand_tokens(N)
    supervision = _gen_supervision(tokens)
@@ -1802,6 +1805,9 @@ def _test_bidirectional_conformer():
print("self prediction logprob = ", self_prediction_logprob)
loss = -(decoder_logprob + reverse_decoder_logprob - self_prediction_logprob)
loss.backward()
if __name__ == '__main__':
_test_bidirectional_conformer()
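The two additions in the test go together: setting feats.requires_grad = True makes the input features a leaf tensor that can accumulate a gradient, and loss.backward() then exercises the full backward path down to that input. A hedged sketch of the same pattern, with an nn.Linear standing in for the bidirectional conformer (the model, shapes, and loss here are illustrative, not the actual test):

import torch
import torch.nn as nn

N, T, C = 10, 25, 40            # illustrative shapes
feats = torch.randn(N, T, C)
feats.requires_grad = True      # input becomes a leaf that accumulates .grad

model = nn.Linear(C, C)         # stand-in for the conformer stack
loss = model(feats).sum()       # stand-in for the combined log-prob loss
loss.backward()

# If backward works end to end, the input gradient exists and matches the input shape.
assert feats.grad is not None
assert feats.grad.shape == feats.shape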