mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Changes to test, RE shifting..
This commit is contained in:
parent
da3c9c7594
commit
0f29f35a42
@ -1867,7 +1867,7 @@ def _test_discrete_bottleneck():
|
|||||||
lr=3.0e-04)
|
lr=3.0e-04)
|
||||||
|
|
||||||
|
|
||||||
scale = 0.7 # determines the feature correlation..should be between 0 and 1.
|
scale = 0.3 # determines the feature correlation..should be between 0 and 1.
|
||||||
#https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
|
#https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
|
||||||
# scale corresponds to rho^2, rho being sqrt(scale).
|
# scale corresponds to rho^2, rho being sqrt(scale).
|
||||||
mutual_information = dim * -0.5 * math.log(1.0 - scale)
|
mutual_information = dim * -0.5 * math.log(1.0 - scale)
|
||||||
@ -1897,9 +1897,13 @@ def _test_discrete_bottleneck():
|
|||||||
if True:
|
if True:
|
||||||
sampled_reversed = ReverseGrad.apply(sampled)
|
sampled_reversed = ReverseGrad.apply(sampled)
|
||||||
predictor_reversed = self_predictor(sampled_reversed)
|
predictor_reversed = self_predictor(sampled_reversed)
|
||||||
#predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
|
|
||||||
# predictor_reversed[:-1,:,:]), dim=0)
|
if True:
|
||||||
predictor_reversed_shifted = predictor_reversed
|
predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
|
||||||
|
predictor_reversed[:-1,:,:]), dim=0)
|
||||||
|
else:
|
||||||
|
# skip shifting.. want to see the effect..
|
||||||
|
predictor_reversed_shifted = predictor_reversed
|
||||||
|
|
||||||
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
|
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
|
||||||
reverse_grad=True)
|
reverse_grad=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user