Version that is successfully optimizing...

Daniel Povey 2021-09-18 16:40:55 +08:00
parent 38081bc3e3
commit 461cb7da6d

@@ -765,7 +765,7 @@ class DiscreteBottleneck(nn.Module):
         # of the mask can be 1, not 0, saving compute..
         d = tot_classes - self.classes_per_group
         c = self.classes_per_group
-        self.pred_cross = nn.Parameter(torch.Tensor(d, d))
+        self.pred_cross = nn.Parameter(torch.zeros(d, d))
         # If d == 4 and c == 2, the expression below has the following value
         # (treat True as 1 and False as 0).
         #tensor([[ True,  True, False, False],
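The only change in this hunk swaps `torch.Tensor(d, d)`, which allocates uninitialized memory, for `torch.zeros(d, d)`, so `pred_cross` starts from a defined value. A minimal standalone sketch of the difference (not the repo's code):

    import torch
    import torch.nn as nn

    d = 4
    # torch.Tensor(d, d) returns whatever bytes happen to be in the buffer;
    # the values are arbitrary and run-dependent.
    uninit = torch.Tensor(d, d)
    # torch.zeros(d, d) is deterministic: every parameter starts at 0.
    zeroed = nn.Parameter(torch.zeros(d, d))
    print(uninit)        # garbage values
    print(zeroed.sum())  # tensor(0., grad_fn=<SumBackward0>)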
@@ -804,7 +804,7 @@ class DiscreteBottleneck(nn.Module):
         This is unnecessary if straight_through_scale == 1.0, since in that
         case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x)
+        x = self.norm_in(x) * 5
         x = self.linear1(x)
         x = x + self.class_offsets
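Multiplying the normalized input by 5 scales the contribution of linear1 to the logits, which (all else being equal) sharpens the downstream softmax; the factor itself is an empirical choice of this commit. A small sketch showing only the scaling effect on softmax entropy (the values here are made up for illustration):

    import torch

    logits = torch.tensor([0.2, 0.5, 0.8])
    for scale in (1.0, 5.0):
        p = (logits * scale).softmax(dim=-1)
        entropy = -(p * (p + 1e-20).log()).sum()
        print(scale, p.tolist(), float(entropy))
    # scale 1.0 gives a nearly uniform distribution; scale 5.0 is much peakier.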
@@ -815,6 +815,18 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
+        if random.random() < 0.01:
+            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
+            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
+            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
+            global_softmax = softmax_temp.mean(dim=(0,1))
+            global_log_softmax = (global_softmax + 1.0e-20).log()
+            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
+            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
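The block added here is a sampled diagnostic (it fires on roughly 1% of calls): the first printed quantity is the mean per-frame entropy of the group softmax, i.e. how confident individual predictions are, and the second is the negentropy of the softmax averaged over time and batch, i.e. how evenly the classes are used overall. A self-contained sketch of the same two statistics, assuming a (S, N, num_groups, classes_per_group) softmax tensor and printing both as entropies:

    import torch

    S, N, num_groups, classes_per_group = 10, 3, 4, 8
    softmax = torch.rand(S, N, num_groups, classes_per_group)
    softmax = softmax / softmax.sum(dim=-1, keepdim=True)

    # Per-frame entropy: low values mean individual frames are confident.
    log_softmax = (softmax + 1.0e-20).log()
    frame_entropy = -(softmax * log_softmax).sum(-1).mean()

    # Entropy of the time/batch-averaged distribution: high values mean the
    # classes are used evenly across the whole batch (no collapse onto a few).
    global_softmax = softmax.mean(dim=(0, 1))
    global_entropy = -(global_softmax * (global_softmax + 1.0e-20).log()).sum(-1).mean()

    print(f"frame entropy={float(frame_entropy):.3f}, "
          f"global entropy={float(global_entropy):.3f}")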
@@ -825,7 +837,7 @@ class DiscreteBottleneck(nn.Module):
         sampled = x
-        if self.training:
+        if self.training and False:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
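Appending `and False` is a quick way to disable the running update of `self.class_probs` without deleting the code; the update itself is a standard exponential moving average. A sketch of that EMA pattern in isolation (variable names follow the diff; the decay value is illustrative, not taken from the repo):

    import torch

    class_probs = torch.full((256,), 1.0 / 256)   # running estimate
    class_probs_decay = 0.95                      # illustrative decay value

    def update(class_probs, batch_probs):
        # New estimate = decay * old estimate + (1 - decay) * current batch stats.
        return class_probs * class_probs_decay + batch_probs * (1.0 - class_probs_decay)

    batch_probs = torch.softmax(torch.randn(256), dim=0)
    class_probs = update(class_probs, batch_probs)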
@@ -1815,9 +1827,9 @@ def _test_discrete_bottleneck():
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dim = 128
     tot_classes = 256
-    num_groups = 4
+    num_groups = 1
     interp_prob = 0.8
-    straight_through_scale = 0.0 # will change
+    straight_through_scale = 1.0 # will change
     need_softmax = True
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1825,16 +1837,17 @@ def _test_discrete_bottleneck():
     self_predictor = nn.Linear(tot_classes, dim).to(device)
-    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
-                            lr=1.0e-03, momentum=0.99)
+    from_feats_predictor = nn.Linear(dim, dim).to(device)
+    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+    model.train()
+    optim = torch.optim.Adam(params=model.parameters(),
+                             lr=1.0e-03)
     for epoch in range(10):
-        state_dict = dict()
-        state_dict['b'] = b.state_dict()
-        state_dict['s'] = self_predictor.state_dict()
-        torch.save(state_dict, f'epoch-{epoch}.pt')
+        torch.save(model.state_dict(), f'epoch-{epoch}.pt')
         for i in range(1000):
             # TODO: also test padding_mask
             T = 300
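Wrapping the pieces in an `nn.ModuleList` means one `state_dict()` covers all three sub-modules, so the per-epoch checkpoint collapses to a single `torch.save` call and the optimizer can be built from `model.parameters()`; loading is symmetric. A short sketch of that save/load round trip (file name and layer sizes are placeholders):

    import torch
    import torch.nn as nn

    a = nn.Linear(128, 256)
    b = nn.Linear(256, 128)
    model = nn.ModuleList([a, b])

    optim = torch.optim.Adam(model.parameters(), lr=1.0e-03)

    # Keys are prefixed by list index: '0.weight', '0.bias', '1.weight', ...
    torch.save(model.state_dict(), 'epoch-0.pt')
    model.load_state_dict(torch.load('epoch-0.pt'))  # restores both sub-modules at once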
@@ -1843,21 +1856,22 @@ def _test_discrete_bottleneck():
             feats = torch.randn(T, N, dim, device=device)
             bn_memory, sampled, softmax = b(feats)
-            predictor = feats  # Could also use `bn_memory`, perhaps. But using
+            predictor = from_feats_predictor(feats)  # Could also use `bn_memory`, perhaps. But using
                                # predictor because it contains everything, will give
                                # us a bound on max information..
-            prob = b.compute_prob(feats, sampled, softmax)
+            prob = b.compute_prob(predictor, sampled, softmax)
-            sampled_reversed = ReverseGrad.apply(sampled)
-            predictor_reversed = self_predictor(sampled_reversed)
-            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                    predictor_reversed[:-1,:,:]), dim=0)
+            if True:
+                sampled_reversed = ReverseGrad.apply(sampled)
+                predictor_reversed = self_predictor(sampled_reversed)
+                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                        predictor_reversed[:-1,:,:]), dim=0)
             self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
                                        reverse_grad=True)
-            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
             normalized_prob = (prob / (T * N)).to('cpu').item()
+            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
             loss = -(prob - self_prob)
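`ReverseGrad.apply(sampled)` is a gradient-reversal layer: the forward pass is the identity, but gradients flowing back through it are negated, so `self_predictor` learns to predict the sampled codes from their (shifted) history while the bottleneck is pushed to make that prediction harder, and `loss = -(prob - self_prob)` maximizes the gap between the two log-probabilities. The repo's `ReverseGrad` is not shown in this hunk; a typical implementation of the idea looks like the sketch below.

    import torch

    class ReverseGrad(torch.autograd.Function):
        """Identity in the forward pass, negated gradient in the backward pass."""

        @staticmethod
        def forward(ctx, x):
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -grad_output

    x = torch.randn(3, requires_grad=True)
    y = ReverseGrad.apply(x).sum()
    y.backward()
    print(x.grad)   # all -1.0: the gradient of sum() is +1, reversed to -1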