mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 12:32:20 +00:00
Get backward working
This commit is contained in:
parent
058fff0365
commit
a20d490332
@ -668,8 +668,10 @@ class SimpleCausalEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ReverseGrad(torch.autograd.Function):
|
class ReverseGrad(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
def forward(ctx, x):
|
def forward(ctx, x):
|
||||||
return x
|
return x
|
||||||
|
@staticmethod
|
||||||
def backward(ctx, x_grad):
|
def backward(ctx, x_grad):
|
||||||
return -x_grad
|
return -x_grad
|
||||||
|
|
||||||
@ -1770,6 +1772,7 @@ def _test_bidirectional_conformer():
|
|||||||
N = 10
|
N = 10
|
||||||
C = num_features
|
C = num_features
|
||||||
feats = torch.randn(N, T, C)
|
feats = torch.randn(N, T, C)
|
||||||
|
feats.requires_grad = True
|
||||||
|
|
||||||
tokens = _gen_rand_tokens(N)
|
tokens = _gen_rand_tokens(N)
|
||||||
supervision = _gen_supervision(tokens)
|
supervision = _gen_supervision(tokens)
|
||||||
@ -1802,6 +1805,9 @@ def _test_bidirectional_conformer():
|
|||||||
|
|
||||||
print("self prediction logprob = ", self_prediction_logprob)
|
print("self prediction logprob = ", self_prediction_logprob)
|
||||||
|
|
||||||
|
loss = -(decoder_logprob + reverse_decoder_logprob - self_prediction_logprob)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
_test_bidirectional_conformer()
|
_test_bidirectional_conformer()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user