Changes to test, RE shifting..

Daniel Povey 2021-09-18 23:04:50 +08:00
parent da3c9c7594
commit 0f29f35a42


@@ -1867,7 +1867,7 @@ def _test_discrete_bottleneck():
                             lr=3.0e-04)
-    scale = 0.7 # determines the feature correlation..should be between 0 and 1.
+    scale = 0.3 # determines the feature correlation..should be between 0 and 1.
     #https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
     # scale corresponds to rho^2, rho being sqrt(scale).
     mutual_information = dim * -0.5 * math.log(1.0 - scale)
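
A minimal sketch of what the changed line controls, under assumptions not shown in this hunk (the test's actual feature generator is not visible here; the generator below and the sizes dim, N, T are hypothetical). If each feature shares a sqrt(scale)-weighted component with a latent variable, its correlation with that latent is rho = sqrt(scale), and the analytic mutual information per dimension is -0.5 * log(1 - rho^2), which is what the test sums over dimensions.

    import math
    import torch

    dim, N, T = 256, 10, 50   # hypothetical sizes for illustration
    scale = 0.3               # rho^2; must lie strictly between 0 and 1

    latent = torch.randn(T, N, dim)
    noise = torch.randn(T, N, dim)
    # Correlation of feats with latent is rho = sqrt(scale) per dimension.
    feats = math.sqrt(scale) * latent + math.sqrt(1.0 - scale) * noise

    # Analytic MI summed over dimensions, matching the line changed above.
    mutual_information = dim * -0.5 * math.log(1.0 - scale)
    print(mutual_information)  # about 0.178 * dim nats for scale = 0.3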
@@ -1897,9 +1897,13 @@ def _test_discrete_bottleneck():
     if True:
         sampled_reversed = ReverseGrad.apply(sampled)
         predictor_reversed = self_predictor(sampled_reversed)
-        #predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-        #                                        predictor_reversed[:-1,:,:]), dim=0)
-        predictor_reversed_shifted = predictor_reversed
+        if True:
+            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                    predictor_reversed[:-1,:,:]), dim=0)
+        else:
+            # skip shifting.. want to see the effect..
+            predictor_reversed_shifted = predictor_reversed
         self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
                                    reverse_grad=True)
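
A minimal sketch of the shift this commit re-enables, assuming the tensors are laid out as (T, N, dim) = (time, batch, feature), which is suggested by the cat of a (1, N, dim) zero frame along dim=0. Prepending a zero frame and dropping the last frame means the self-predictor's output paired with the sample at time t was computed without seeing frame t itself.

    import torch

    T, N, dim = 5, 3, 4   # hypothetical sizes for illustration
    device = torch.device("cpu")
    predictor_reversed = torch.randn(T, N, dim, device=device)

    predictor_reversed_shifted = torch.cat(
        (torch.zeros(1, N, dim).to(device),   # zero frame stands in for "no past context"
         predictor_reversed[:-1, :, :]),      # frames 0..T-2 move to positions 1..T-1
        dim=0)

    assert predictor_reversed_shifted.shape == (T, N, dim)
    assert torch.all(predictor_reversed_shifted[0] == 0)
    assert torch.equal(predictor_reversed_shifted[1:], predictor_reversed[:-1])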