Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-14 12:32:20 +00:00
Get backward working

parent 058fff0365
commit a20d490332
@@ -668,8 +668,10 @@ class SimpleCausalEncoderLayer(nn.Module):

class ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x_grad):
        return -x_grad

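The new ReverseGrad is a gradient reversal layer: it passes activations through unchanged in the forward pass but flips the sign of the incoming gradient in the backward pass. A minimal, self-contained sketch of that behavior follows; the class body is copied from the hunk above, while the check at the end is illustrative and not part of the commit.

import torch


class ReverseGrad(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x_grad):
        return -x_grad


# Sanity check of the sign flip: for y = x, d(sum(y))/dx is all ones,
# so the reversed gradient should be all minus-ones.
x = torch.randn(3, 4, requires_grad=True)
y = ReverseGrad.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, -torch.ones_like(x)))  # True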
@@ -1770,6 +1772,7 @@ def _test_bidirectional_conformer():
    N = 10
    C = num_features
    feats = torch.randn(N, T, C)
    feats.requires_grad = True

    tokens = _gen_rand_tokens(N)
    supervision = _gen_supervision(tokens)
@@ -1802,6 +1805,9 @@ def _test_bidirectional_conformer():

    print("self prediction logprob = ", self_prediction_logprob)

    loss = -(decoder_logprob + reverse_decoder_logprob - self_prediction_logprob)
    loss.backward()


if __name__ == '__main__':
    _test_bidirectional_conformer()
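These added lines combine the three log-probability terms into one scalar loss and call backward(), which is what "Get backward working" refers to: gradients now flow back to the leaf input feats. Below is a hypothetical, self-contained sketch of that pattern using stand-in scalar computations (the real values come from the conformer's decoder, reverse decoder, and self-prediction outputs, which are not shown in this hunk); the shape values for T and C and the stand-in formulas are assumptions, and only the sign convention and the backward() call mirror the diff.

import torch

# Stand-in shapes: N = 10 comes from the test; T and C are assumed here.
N, T, C = 10, 25, 80
feats = torch.randn(N, T, C)
feats.requires_grad = True

# Hypothetical stand-ins for the three scalar log-probabilities; in the test
# they are produced by the model from feats, not by these formulas.
decoder_logprob = feats.mean() * 0.1
reverse_decoder_logprob = feats.std() * 0.1
self_prediction_logprob = feats.abs().mean() * 0.1

# Same sign convention as the diff: maximize the two decoder terms while
# penalizing the self-prediction term.
loss = -(decoder_logprob + reverse_decoder_logprob - self_prediction_logprob)
loss.backward()

# backward() populates the gradient of the leaf input.
print(feats.grad.shape)  # torch.Size([10, 25, 80])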