diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 8d0eef40c..89ec7b571 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -765,7 +765,7 @@ class DiscreteBottleneck(nn.Module):
         # of the mask can be 1, not 0, saving compute..
         d = tot_classes - self.classes_per_group
         c = self.classes_per_group
-        self.pred_cross = nn.Parameter(torch.Tensor(d, d))
+        self.pred_cross = nn.Parameter(torch.zeros(d, d))
         # If d == 4 and c == 2, the expression below has the following value
         # (treat True as 1 and False as 0).
         #tensor([[ True,  True, False, False],
@@ -804,7 +804,7 @@
            This is unnecessary if straight_through_scale == 1.0, since in that
            case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x)
+        x = self.norm_in(x) * 5
         x = self.linear1(x)
         x = x + self.class_offsets
 
@@ -815,6 +815,18 @@
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
 
+        if random.random() < 0.01:
+            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
+            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
+            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
+
+            global_softmax = softmax_temp.mean(dim=(0,1))
+            global_log_softmax = (global_softmax + 1.0e-20).log()
+            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
+
+            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
+
+
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
@@ -825,7 +837,7 @@
 
         sampled = x
 
-        if self.training:
+        if self.training and False:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
@@ -1815,9 +1827,9 @@ def _test_discrete_bottleneck():
 
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     dim = 128
     tot_classes = 256
-    num_groups = 4
+    num_groups = 1
     interp_prob = 0.8
-    straight_through_scale = 0.0 # will change
+    straight_through_scale = 1.0 # will change
     need_softmax = True
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1825,16 +1837,17 @@
 
     self_predictor = nn.Linear(tot_classes, dim).to(device)
+    from_feats_predictor = nn.Linear(dim, dim).to(device)
 
-    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
-                            lr=1.0e-03, momentum=0.99)
+    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+    model.train()
+
+    optim = torch.optim.Adam(params=model.parameters(),
+                             lr=1.0e-03)
 
     for epoch in range(10):
-        state_dict = dict()
-        state_dict['b'] = b.state_dict()
-        state_dict['s'] = self_predictor.state_dict()
-        torch.save(state_dict, f'epoch-{epoch}.pt')
+        torch.save(model.state_dict(), f'epoch-{epoch}.pt')
 
         for i in range(1000):
             # TODO: also test padding_mask
             T = 300
 
@@ -1843,21 +1856,22 @@
 
             feats = torch.randn(T, N, dim, device=device)
 
             bn_memory, sampled, softmax = b(feats)
-            predictor = feats   # Could also use `bn_memory`, perhaps.  But using
+            predictor = from_feats_predictor(feats)  # Could also use `bn_memory`, perhaps.  But using
                                 # predictor because it contains everything, will give
                                 # us a bound on max information..
-            prob = b.compute_prob(feats, sampled, softmax)
+            prob = b.compute_prob(predictor, sampled, softmax)
 
-            sampled_reversed = ReverseGrad.apply(sampled)
-            predictor_reversed = self_predictor(sampled_reversed)
-            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                    predictor_reversed[:-1,:,:]), dim=0)
+            if True:
+                sampled_reversed = ReverseGrad.apply(sampled)
+                predictor_reversed = self_predictor(sampled_reversed)
+                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                        predictor_reversed[:-1,:,:]), dim=0)
 
-            self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
-                                       reverse_grad=True)
+                self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
+                                           reverse_grad=True)
+            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
             normalized_prob = (prob / (T * N)).to('cpu').item()
-            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
 
             loss = -(prob - self_prob)
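
Note on the hunk at @@ -815,6 +815,18 @@: it adds an occasional printout of the per-group softmax entropy and of the negentropy of the globally averaged class distribution. The following is a minimal standalone sketch of that diagnostic for readers who want to run it outside DiscreteBottleneck.forward(); the helper name entropy_diagnostic and the shapes (S, N, num_groups, classes_per_group) are assumptions inferred from the diff, not part of the patch.

# Standalone sketch of the entropy diagnostic added above.  Assumes `softmax`
# has shape (S, N, num_groups * classes_per_group) and is normalized per group.
import torch

def entropy_diagnostic(softmax, num_groups, classes_per_group):
    # Hypothetical helper, not present in conformer.py.
    S, N, _ = softmax.shape
    p = softmax.reshape(S, N, num_groups, classes_per_group)
    logp = (p + 1.0e-20).log()              # epsilon avoids log(0)
    negentropy = (p * logp).sum(-1).mean()  # mean over frames and groups

    # Negentropy of the distribution averaged over all frames: shows how evenly
    # the classes are used overall, as opposed to per-frame peakiness.
    p_global = p.mean(dim=(0, 1))
    logp_global = (p_global + 1.0e-20).log()
    global_negentropy = (p_global * logp_global).sum(-1).mean()
    return -negentropy, global_negentropy

if __name__ == '__main__':
    S, N, G, C = 10, 4, 1, 256
    logits = torch.randn(S, N, G, C)
    sm = logits.softmax(dim=-1).reshape(S, N, G * C)
    entropy, global_negentropy = entropy_diagnostic(sm, G, C)
    print("Entropy = ", entropy.item(),
          ", averaged negentropy = ", global_negentropy.item())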