Some updates to tests; still figuring out issues.

Daniel Povey 2021-09-18 21:47:31 +08:00
parent 461cb7da6d
commit da3c9c7594


@@ -676,6 +676,23 @@ class ReverseGrad(torch.autograd.Function):
def backward(ctx, x_grad):
return -x_grad
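ReverseGrad is a gradient-reversal layer: the forward pass (above this hunk) returns its input unchanged, and the backward pass above negates the incoming gradient. A minimal sketch of the effect, assuming the identity forward and torch in scope; this is illustration only, not part of the commit:

x = torch.randn(4, requires_grad=True)
y = ReverseGrad.apply(x)   # forward pass: y equals x
y.sum().backward()
print(x.grad)              # all -1.0: the gradient arrives with its sign flipped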
class DebugGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, x, name):
ctx.save_for_backward(x)
ctx.name = name
return x
@staticmethod
def backward(ctx, x_grad):
x, = ctx.saved_tensors
x_grad_sum = x_grad.sum().to('cpu').item()
x_grad_x_sum = (x_grad * x).sum().to('cpu').item()
print(f"For {ctx.name}, x_grad_sum = {x_grad_sum}, x_grad_x_sum = {x_grad_x_sum}")
return x_grad, None
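DebugGrad is likewise an identity in the forward pass; its backward prints the sum of the incoming gradient and its inner product with the saved input, then passes the gradient through unchanged. A hedged usage sketch, not part of the commit:

x = torch.randn(3, 5, requires_grad=True)
h = DebugGrad.apply(x, "hidden")   # forward pass: h equals x
(h ** 2).sum().backward()          # prints: For hidden, x_grad_sum = ..., x_grad_x_sum = ...
# x.grad is exactly what it would be without the hook, since backward returns x_grad as-is.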
class DiscreteBottleneck(nn.Module):
"""
This layer forces its input through an information bottleneck via
@@ -804,7 +821,7 @@ class DiscreteBottleneck(nn.Module):
This is unnecessary if straight_through_scale == 1.0, since in that
case it would not affect the backpropagated derivatives.
"""
x = self.norm_in(x) * 5
x = self.norm_in(x) * 5 # * 5 gives lower entropy; starts training faster.
x = self.linear1(x)
x = x + self.class_offsets
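The "* 5" comment above refers to the fact that larger-magnitude inputs to the following linear layer and softmax give sharper, lower-entropy class distributions, like a softmax temperature below 1. A rough illustration, not part of the commit (in the real path the scaling also passes through linear1 and class_offsets):

import torch
logits = torch.randn(16)
def entropy(p):
    return -(p * (p + 1.0e-20).log()).sum()
print(entropy(logits.softmax(0)), entropy((5 * logits).softmax(0)))  # the scaled logits give lower entropy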
@@ -815,7 +832,7 @@ class DiscreteBottleneck(nn.Module):
# This is a little wasteful since we already compute the softmax
# inside 'flow_sample'.
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
if random.random() < 0.01:
if random.random() < 0.001:
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -824,7 +841,7 @@ class DiscreteBottleneck(nn.Module):
global_log_softmax = (global_softmax + 1.0e-20).log()
global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
@@ -845,6 +862,10 @@ class DiscreteBottleneck(nn.Module):
self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
embedding = self.norm_out(self.linear2(x))
#if random.random() < 0.01:
# return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
#else:
return (embedding, sampled, softmax)
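The class_offsets update at the top of this hunk boosts the logit offset of any class whose running probability has fallen below prob_floor, pushing rarely-used classes back into circulation. A small illustration with made-up values (prob_floor and prob_boost are the module's own attributes; the numbers here are invented):

import torch
class_probs = torch.tensor([0.30, 0.0005, 0.20])   # invented running class probabilities
prob_floor, prob_boost = 1.0e-03, 0.1              # invented values
class_offsets = torch.zeros(3)
class_offsets += (class_probs < prob_floor) * prob_boost
print(class_offsets)   # only the under-used class (index 1) gets its offset increased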
@@ -885,7 +906,7 @@ class DiscreteBottleneck(nn.Module):
sampled = ReverseGrad.apply(sampled)
if softmax is None:
softmax = sampled
else:
elif reverse_grad:
softmax = ReverseGrad.apply(softmax)
logprobs = self.pred_linear(x)
@@ -1825,11 +1846,11 @@ def _test_bidirectional_conformer():
def _test_discrete_bottleneck():
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dim = 128
dim = 256
tot_classes = 256
num_groups = 1
interp_prob = 0.8
straight_through_scale = 1.0 # will change
num_groups = 8
interp_prob = 1.0
straight_through_scale = 0.0 # will change
need_softmax = True
b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1843,29 +1864,42 @@ def _test_discrete_bottleneck():
model.train()
optim = torch.optim.Adam(params=model.parameters(),
lr=1.0e-03)
lr=3.0e-04)
scale = 0.7 # determines the feature correlation; should be between 0 and 1.
# https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation: MI per dim = -0.5 log(1 - rho^2);
# scale corresponds to rho^2, rho being sqrt(scale).
mutual_information = dim * -0.5 * math.log(1.0 - scale)
print("mutual_information = ", mutual_information)
for epoch in range(10):
torch.save(model.state_dict(), f'epoch-{epoch}.pt')
for i in range(1000):
for i in range(2000):
# TODO: also test padding_mask
T = 300
N = 10
feats = torch.randn(T, N, dim, device=device)
feats2 = (feats * scale ** 0.5) + ((1.0 - scale) ** 0.5 * torch.randn(T, N, dim, device=device))
#print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")
bn_memory, sampled, softmax = b(feats)
predictor = from_feats_predictor(feats) # Could also use `bn_memory`, perhaps. But using
# predictor because it contains everything, will give
# us a bound on max information..
# Using feats2 instead of feats limits the mutual information to the MI
# between feats and feats2, which we computed and printed above as
# mutual_information.
predictor = from_feats_predictor(feats2)
prob = b.compute_prob(predictor, sampled, softmax)
if True:
sampled_reversed = ReverseGrad.apply(sampled)
predictor_reversed = self_predictor(sampled_reversed)
predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
predictor_reversed[:-1,:,:]), dim=0)
#predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
# predictor_reversed[:-1,:,:]), dim=0)
predictor_reversed_shifted = predictor_reversed
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
reverse_grad=True)
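In this last block, self_prob appears to measure how well sampled can be predicted from a reversed-gradient copy of itself; the commented-out shift would have restricted the self-predictor at frame t to frames before t. A minimal sketch of what that one-frame shift does, outside this commit:

import torch
T, N, dim = 4, 1, 3
p = torch.arange(T * N * dim, dtype=torch.float32).reshape(T, N, dim)
p_shifted = torch.cat((torch.zeros(1, N, dim), p[:-1, :, :]), dim=0)
# p_shifted[0] is all zeros and p_shifted[t] == p[t - 1] for t >= 1,
# so a predictor built from p_shifted at frame t sees only earlier frames.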