Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)
commit 461cb7da6d
parent 38081bc3e3

    Version that is successfully optimizing...
@@ -765,7 +765,7 @@ class DiscreteBottleneck(nn.Module):
         # of the mask can be 1, not 0, saving compute..
         d = tot_classes - self.classes_per_group
         c = self.classes_per_group
-        self.pred_cross = nn.Parameter(torch.Tensor(d, d))
+        self.pred_cross = nn.Parameter(torch.zeros(d, d))
         # If d == 4 and c == 2, the expression below has the following value
         # (treat True as 1 and False as 0).
         #tensor([[ True,  True, False, False],
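Note on the hunk above: torch.Tensor(d, d) allocates uninitialized memory, so pred_cross previously started from arbitrary (possibly huge or NaN) values, whereas torch.zeros(d, d) gives a well-defined all-zero start. A minimal illustration of the difference (the size 4 here is arbitrary, not the model's actual d):

    import torch
    import torch.nn as nn

    d = 4  # arbitrary size for illustration only

    # torch.Tensor(d, d) returns uninitialized (garbage) memory.
    uninit = nn.Parameter(torch.Tensor(d, d))

    # torch.zeros(d, d) is a well-defined all-zero starting point, so the
    # cross-prediction term initially contributes nothing.
    zero_init = nn.Parameter(torch.zeros(d, d))

    print(uninit)     # arbitrary values, possibly NaN/inf
    print(zero_init)  # all zeros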
@@ -804,7 +804,7 @@ class DiscreteBottleneck(nn.Module):
         This is unnecessary if straight_through_scale == 1.0, since in that
         case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x)
+        x = self.norm_in(x) * 5
         x = self.linear1(x)
         x = x + self.class_offsets
 
@@ -815,6 +815,18 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
+        if random.random() < 0.01:
+            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
+            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
+            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
+
+            global_softmax = softmax_temp.mean(dim=(0,1))
+            global_log_softmax = (global_softmax + 1.0e-20).log()
+            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
+
+            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
+
+
 
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
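The block added above occasionally prints two diagnostics: the mean per-frame entropy of the group softmax, and the negentropy of the softmax averaged over all frames and groups. The standalone sketch below reproduces that computation on random data; the sizes of S, N, num_groups and classes_per_group are placeholders, not values from the test.

    import torch

    S, N, num_groups, classes_per_group = 10, 3, 2, 8  # placeholder sizes

    # Random distributions with the same layout as `softmax_temp` in the diff.
    logits = torch.randn(S, N, num_groups, classes_per_group)
    softmax_temp = logits.softmax(dim=-1)

    # Per-frame negentropy: sum_k p_k log p_k, averaged over frames and groups.
    logsoftmax_temp = (softmax_temp + 1.0e-20).log()
    negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()

    # Negentropy of the distribution averaged over all frames.  By concavity of
    # entropy, the averaged distribution has entropy >= the mean per-frame
    # entropy; a large gap means different frames favour different classes
    # rather than all collapsing onto one distribution.
    global_softmax = softmax_temp.mean(dim=(0, 1))
    global_log_softmax = (global_softmax + 1.0e-20).log()
    global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()

    print("Entropy =", -negentropy.item(),
          ", averaged negentropy =", global_negentropy.item())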
@@ -825,7 +837,7 @@ class DiscreteBottleneck(nn.Module):
 
         sampled = x
 
-        if self.training:
+        if self.training and False:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
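For reference, the branch disabled above maintains an exponential moving average of class usage: class_probs <- decay * class_probs + (1 - decay) * mean_class_probs. A quick standalone illustration of one such update (sizes and decay constant are placeholders):

    import torch

    tot_classes = 8           # placeholder number of classes
    class_probs_decay = 0.9   # placeholder decay constant

    # Running average, initialised to a uniform distribution.
    class_probs = torch.full((tot_classes,), 1.0 / tot_classes)

    # One EMA step given a batch estimate of how often each class was chosen.
    mean_class_probs = torch.rand(tot_classes)
    mean_class_probs /= mean_class_probs.sum()

    class_probs = (class_probs * class_probs_decay +
                   mean_class_probs * (1.0 - class_probs_decay))
    print(class_probs)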
@@ -1815,9 +1827,9 @@ def _test_discrete_bottleneck():
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dim = 128
     tot_classes = 256
-    num_groups = 4
+    num_groups = 1
     interp_prob = 0.8
-    straight_through_scale = 0.0 # will change
+    straight_through_scale = 1.0 # will change
     need_softmax = True
 
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1825,16 +1837,17 @@ def _test_discrete_bottleneck():
 
     self_predictor = nn.Linear(tot_classes, dim).to(device)
+    from_feats_predictor = nn.Linear(dim, dim).to(device)
 
-    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
-                            lr=1.0e-03, momentum=0.99)
+    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+    model.train()
+
+    optim = torch.optim.Adam(params=model.parameters(),
+                             lr=1.0e-03)
 
 
     for epoch in range(10):
-        state_dict = dict()
-        state_dict['b'] = b.state_dict()
-        state_dict['s'] = self_predictor.state_dict()
-        torch.save(state_dict, f'epoch-{epoch}.pt')
+        torch.save(model.state_dict(), f'epoch-{epoch}.pt')
         for i in range(1000):
             # TODO: also test padding_mask
             T = 300
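The hunk above replaces the hand-built per-module checkpoint dict and SGD-with-momentum with a single nn.ModuleList, whose state_dict() covers every sub-module, and an Adam optimizer over its parameters. A self-contained sketch of that pattern, using plain nn.Linear stand-ins for the real bottleneck and predictor modules:

    import torch
    import torch.nn as nn

    # Placeholder sub-modules standing in for the bottleneck and the predictors.
    b = nn.Linear(128, 256)
    self_predictor = nn.Linear(256, 128)
    from_feats_predictor = nn.Linear(128, 128)

    # Grouping them lets one state_dict()/load_state_dict() cover everything.
    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
    model.train()

    optim = torch.optim.Adam(params=model.parameters(), lr=1.0e-03)

    # Save and restore one checkpoint for the whole collection.
    torch.save(model.state_dict(), 'epoch-0.pt')
    model.load_state_dict(torch.load('epoch-0.pt'))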
@@ -1843,21 +1856,22 @@ def _test_discrete_bottleneck():
             feats = torch.randn(T, N, dim, device=device)
             bn_memory, sampled, softmax = b(feats)
 
-            predictor = feats  # Could also use `bn_memory`, perhaps. But using
+            predictor = from_feats_predictor(feats)  # Could also use `bn_memory`, perhaps. But using
                                # predictor because it contains everything, will give
                                # us a bound on max information..
-            prob = b.compute_prob(feats, sampled, softmax)
+            prob = b.compute_prob(predictor, sampled, softmax)
 
-            sampled_reversed = ReverseGrad.apply(sampled)
-            predictor_reversed = self_predictor(sampled_reversed)
-            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                    predictor_reversed[:-1,:,:]), dim=0)
+            if True:
+                sampled_reversed = ReverseGrad.apply(sampled)
+                predictor_reversed = self_predictor(sampled_reversed)
+                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                        predictor_reversed[:-1,:,:]), dim=0)
 
-            self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
-                                       reverse_grad=True)
+                self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
+                                           reverse_grad=True)
+                normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
             normalized_prob = (prob / (T * N)).to('cpu').item()
-            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
             loss = -(prob - self_prob)
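The test above passes the samples through ReverseGrad.apply(sampled) before the self-predictor, so the self-predictor learns to predict the samples while the gradient it sends back into the bottleneck is flipped in sign; that is what makes loss = -(prob - self_prob) a sensible objective. ReverseGrad itself is not part of this diff; a generic gradient-reversal function of that kind might look like the sketch below (an assumption for illustration, not the repository's exact code).

    import torch

    class ReverseGrad(torch.autograd.Function):
        """Identity in the forward pass, negated gradient in the backward pass.
        A generic gradient-reversal sketch, not the repository's definition."""

        @staticmethod
        def forward(ctx, x):
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -grad_output

    # Usage: y equals x in the forward pass, but gradients through it flip sign.
    x = torch.randn(4, requires_grad=True)
    y = ReverseGrad.apply(x)
    y.sum().backward()
    print(x.grad)  # all -1.0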