Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

Version that is successfully optimizing...

parent 38081bc3e3
commit 461cb7da6d
@@ -765,7 +765,7 @@ class DiscreteBottleneck(nn.Module):
         # of the mask can be 1, not 0, saving compute..
         d = tot_classes - self.classes_per_group
         c = self.classes_per_group
-        self.pred_cross = nn.Parameter(torch.Tensor(d, d))
+        self.pred_cross = nn.Parameter(torch.zeros(d, d))
         # If d == 4 and c == 2, the expression below has the following value
         # (treat True as 1 and False as 0).
         #tensor([[ True, True, False, False],
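The only change in this hunk swaps an uninitialized tensor for a zero tensor: torch.Tensor(d, d) allocates storage without initializing it, so pred_cross would start from arbitrary memory contents, while torch.zeros(d, d) gives a deterministic all-zero start. A minimal sketch of the difference (the rationale is inferred, not stated in the commit):

    import torch
    import torch.nn as nn

    d = 4
    uninit = nn.Parameter(torch.Tensor(d, d))  # uninitialized storage: arbitrary values
    zeroed = nn.Parameter(torch.zeros(d, d))   # deterministic all-zero start
    print(uninit)   # may contain garbage or very large values
    print(zeroed)   # always zeros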
@@ -804,7 +804,7 @@ class DiscreteBottleneck(nn.Module):
         This is unnecessary if straight_through_scale == 1.0, since in that
         case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x)
+        x = self.norm_in(x) * 5
         x = self.linear1(x)
         x = x + self.class_offsets
 
@@ -815,6 +815,18 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
+        if random.random() < 0.01:
+            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
+            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
+            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
+
+            global_softmax = softmax_temp.mean(dim=(0,1))
+            global_log_softmax = (global_softmax + 1.0e-20).log()
+            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
+
+            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
+
+
 
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
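The added block is an occasional diagnostic (roughly 1% of forward calls): it prints the mean entropy of the per-frame group softmaxes alongside the negentropy of the distribution averaged over all frames. A standalone sketch of the same computation with assumed shapes (S frames, N sequences, num_groups groups, classes_per_group classes per group); a large gap between per-frame confidence and the averaged distribution's peakedness indicates that frames are confident while overall class usage stays diverse:

    import torch

    # Assumed shapes, mirroring the variables in the diff above.
    S, N, num_groups, classes_per_group = 20, 3, 4, 64
    softmax = torch.rand(S, N, num_groups, classes_per_group)
    softmax = softmax / softmax.sum(dim=-1, keepdim=True)      # normalize each group

    # Mean negentropy of the per-frame distributions: how peaked each frame is.
    negentropy = (softmax * (softmax + 1.0e-20).log()).sum(-1).mean()

    # Negentropy of the distribution averaged over all frames and sequences:
    # how peaked the overall class usage is.
    global_softmax = softmax.mean(dim=(0, 1))
    global_log_softmax = (global_softmax + 1.0e-20).log()
    global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()

    print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)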
@@ -825,7 +837,7 @@ class DiscreteBottleneck(nn.Module):
 
         sampled = x
 
-        if self.training:
+        if self.training and False:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
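The `and False` above disables the running estimate of class usage during training; for reference, the update it turns off is a standard exponential moving average. A minimal standalone sketch under assumed shapes (the decay constant here is illustrative, not taken from the repo):

    import torch

    tot_classes = 256
    class_probs = torch.full((tot_classes,), 1.0 / tot_classes)  # running estimate
    class_probs_decay = 0.9                                      # assumed decay constant

    x = torch.rand(300, 5, tot_classes).softmax(dim=-1)          # stand-in for the sampled output
    mean_class_probs = torch.mean(x.detach(), dim=(0, 1))
    class_probs = (class_probs * class_probs_decay +
                   mean_class_probs * (1.0 - class_probs_decay))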
@@ -1815,9 +1827,9 @@ def _test_discrete_bottleneck():
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dim = 128
     tot_classes = 256
-    num_groups = 4
+    num_groups = 1
     interp_prob = 0.8
-    straight_through_scale = 0.0 # will change
+    straight_through_scale = 1.0 # will change
     need_softmax = True
 
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1825,16 +1837,17 @@ def _test_discrete_bottleneck():
 
     self_predictor = nn.Linear(tot_classes, dim).to(device)
 
-    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
-                            lr=1.0e-03, momentum=0.99)
+    from_feats_predictor = nn.Linear(dim, dim).to(device)
+
+    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+    model.train()
 
+    optim = torch.optim.Adam(params=model.parameters(),
+                             lr=1.0e-03)
 
 
     for epoch in range(10):
-        state_dict = dict()
-        state_dict['b'] = b.state_dict()
-        state_dict['s'] = self_predictor.state_dict()
-        torch.save(state_dict, f'epoch-{epoch}.pt')
+        torch.save(model.state_dict(), f'epoch-{epoch}.pt')
         for i in range(1000):
            # TODO: also test padding_mask
            T = 300
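The test setup now collects the three modules in an nn.ModuleList, which provides a single .parameters() iterator for the optimizer (switched from SGD with momentum to Adam) and a single state_dict for checkpointing, replacing the hand-built dict. A minimal sketch of the same pattern with stand-in modules (the dims are illustrative):

    import torch
    import torch.nn as nn

    dim, tot_classes = 128, 256
    b = nn.Linear(dim, tot_classes)              # stand-in for DiscreteBottleneck
    self_predictor = nn.Linear(tot_classes, dim)
    from_feats_predictor = nn.Linear(dim, dim)

    # One container: one .parameters() iterator, one state_dict for all three modules.
    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
    model.train()
    optim = torch.optim.Adam(params=model.parameters(), lr=1.0e-03)

    torch.save(model.state_dict(), 'epoch-0.pt')        # save everything at once
    model.load_state_dict(torch.load('epoch-0.pt'))     # restore the same way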
@@ -1843,21 +1856,22 @@ def _test_discrete_bottleneck():
            feats = torch.randn(T, N, dim, device=device)
            bn_memory, sampled, softmax = b(feats)
 
-           predictor = feats # Could also use `bn_memory`, perhaps.  But using
+           predictor = from_feats_predictor(feats) # Could also use `bn_memory`, perhaps.  But using
                              # predictor because it contains everything, will give
                              # us a bound on max information..
-           prob = b.compute_prob(feats, sampled, softmax)
+           prob = b.compute_prob(predictor, sampled, softmax)
 
-           sampled_reversed = ReverseGrad.apply(sampled)
-           predictor_reversed = self_predictor(sampled_reversed)
-           predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                   predictor_reversed[:-1,:,:]), dim=0)
+           if True:
+               sampled_reversed = ReverseGrad.apply(sampled)
+               predictor_reversed = self_predictor(sampled_reversed)
+               predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                       predictor_reversed[:-1,:,:]), dim=0)
 
           self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
                                      reverse_grad=True)
+          normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
           normalized_prob = (prob / (T * N)).to('cpu').item()
-          normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
           loss = -(prob - self_prob)
 
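The test wraps `sampled` in ReverseGrad before feeding it to self_predictor, which flips the sign of the gradient flowing back through that path. The commit does not show ReverseGrad's implementation; a common gradient-reversal Function looks like the sketch below (the repo's actual ReverseGrad may differ):

    import torch

    class ReverseGradSketch(torch.autograd.Function):
        # Identity in the forward pass; the gradient is negated on the way back.
        @staticmethod
        def forward(ctx, x):
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -grad_output

    x = torch.randn(4, requires_grad=True)
    y = ReverseGradSketch.apply(x).sum()
    y.backward()
    print(x.grad)   # all -1.0: the gradient of sum() is +1, reversed to -1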