diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 89ec7b571..00da79cdc 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -676,6 +676,23 @@ class ReverseGrad(torch.autograd.Function):
     def backward(ctx, x_grad):
         return -x_grad
 
+
+class DebugGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, name):
+        ctx.save_for_backward(x)
+        ctx.name = name
+        return x
+    @staticmethod
+    def backward(ctx, x_grad):
+        x, = ctx.saved_tensors
+        x_grad_sum = x_grad.sum().to('cpu').item()
+        x_grad_x_sum = (x_grad * x).sum().to('cpu').item()
+        print(f"For {ctx.name}, x_grad_sum = {x_grad_sum}, x_grad_x_sum = {x_grad_x_sum}")
+        return x_grad, None
+
+
+
 class DiscreteBottleneck(nn.Module):
     """
     This layer forces its input through an information bottleneck via
@@ -804,7 +821,7 @@ class DiscreteBottleneck(nn.Module):
            This is unnecessary if straight_through_scale == 1.0, since in that
            case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x) * 5
+        x = self.norm_in(x) * 5 # * 5 gives lower entropy.. starts training faster..
         x = self.linear1(x)
         x = x + self.class_offsets
 
@@ -815,7 +832,7 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
-        if random.random() < 0.01:
+        if random.random() < 0.001:
             softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
             logsoftmax_temp = (softmax_temp + 1.0e-20).log()
             negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -824,7 +841,7 @@ class DiscreteBottleneck(nn.Module):
             global_log_softmax = (global_softmax + 1.0e-20).log()
             global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
 
-            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
+            print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
 
 
 
@@ -845,6 +862,10 @@ class DiscreteBottleneck(nn.Module):
             self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
 
         embedding = self.norm_out(self.linear2(x))
+
+        #if random.random() < 0.01:
+        #    return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
+        #else:
         return (embedding, sampled, softmax)
 
 
@@ -885,7 +906,7 @@ class DiscreteBottleneck(nn.Module):
             sampled = ReverseGrad.apply(sampled)
         if softmax is None:
             softmax = sampled
-        else:
+        elif reverse_grad:
             softmax = ReverseGrad.apply(softmax)
 
         logprobs = self.pred_linear(x)
@@ -1825,11 +1846,11 @@ def _test_bidirectional_conformer():
 
 def _test_discrete_bottleneck():
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-    dim = 128
+    dim = 256
     tot_classes = 256
-    num_groups = 1
-    interp_prob = 0.8
-    straight_through_scale = 1.0 # will change
+    num_groups = 8
+    interp_prob = 1.0
+    straight_through_scale = 0.0 # will change
     need_softmax = True
 
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1843,29 +1864,42 @@ def _test_discrete_bottleneck():
 
     model.train()
     optim = torch.optim.Adam(params=model.parameters(),
-                             lr=1.0e-03)
+                             lr=3.0e-04)
+    scale = 0.7 # determines the feature correlation..should be between 0 and 1.
+    #https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
+    # scale corresponds to rho^2, rho being sqrt(scale).
+    mutual_information = dim * -0.5 * math.log(1.0 - scale)
+    print("mutual_information = ", mutual_information)
+
    for epoch in range(10):
         torch.save(model.state_dict(), f'epoch-{epoch}.pt')
 
-        for i in range(1000):
+        for i in range(2000):
             # TODO: also test padding_mask
             T = 300
             N = 10
             feats = torch.randn(T, N, dim, device=device)
+
+            feats2 = (feats * scale ** 0.5) + ((1.0 - scale) ** 0.5 * torch.randn(T, N, dim, device=device))
+
+            #print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")
+
             bn_memory, sampled, softmax = b(feats)
 
-            predictor = from_feats_predictor(feats) # Could also use `bn_memory`, perhaps. But using
-                                                    # predictor because it contains everything, will give
-                                                    # us a bound on max information..
+            # using feats2 instead of feats will limit the mutual information,
+            # to the MI between feats and feats2, which we computed and printed
+            # above as mutual_information.
+            predictor = from_feats_predictor(feats2)
 
             prob = b.compute_prob(predictor, sampled, softmax)
 
             if True:
                 sampled_reversed = ReverseGrad.apply(sampled)
                 predictor_reversed = self_predictor(sampled_reversed)
 
-                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                        predictor_reversed[:-1,:,:]), dim=0)
+                #predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                #                                        predictor_reversed[:-1,:,:]), dim=0)
+                predictor_reversed_shifted = predictor_reversed
                 self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
                                            reverse_grad=True)
 
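
Reviewer note (not part of the patch): the mutual-information bookkeeping in _test_discrete_bottleneck can be sanity-checked numerically. With feats2 = sqrt(scale) * feats + sqrt(1 - scale) * noise and feats, noise i.i.d. standard Gaussian, each dimension of (feats, feats2) is bivariate Gaussian with correlation rho = sqrt(scale), so the MI is -0.5 * log(1 - rho^2) = -0.5 * log(1 - scale) nats per dimension, i.e. dim * -0.5 * log(1 - scale) in total. A minimal, self-contained sketch (sample size and shapes chosen arbitrarily):

# Numerically check that the empirical correlation between feats and feats2
# matches rho = sqrt(scale), and that the Gaussian MI formula gives the value
# printed as `mutual_information` in the test above.
import math
import torch

dim = 256
scale = 0.7
n = 50000

feats = torch.randn(n, dim)
feats2 = feats * scale ** 0.5 + (1.0 - scale) ** 0.5 * torch.randn(n, dim)

# Per-dimension Pearson correlation between feats and feats2.
fc = feats - feats.mean(dim=0)
f2c = feats2 - feats2.mean(dim=0)
rho_hat = (fc * f2c).mean(dim=0) / (fc.std(dim=0, unbiased=False) * f2c.std(dim=0, unbiased=False))

mi_hat = (-0.5 * torch.log(1.0 - rho_hat ** 2)).sum().item()  # Gaussian MI from correlation
mi_analytic = dim * -0.5 * math.log(1.0 - scale)
print(f"estimated MI = {mi_hat:.1f} nats, analytic MI = {mi_analytic:.1f} nats")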
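
Reviewer note (not part of the patch): a small sketch of how the new DebugGrad function can be wired in for gradient debugging, in the spirit of the commented-out call in DiscreteBottleneck.forward(). The class body is copied from the patch so the snippet runs on its own; the tensor shapes and the toy loss are made up for illustration. DebugGrad is the identity in the forward pass; in the backward pass it prints the sum of the incoming gradient and of grad * activation for the wrapped tensor, then passes the gradient through unchanged (None corresponds to the non-tensor `name` argument).

import torch

class DebugGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, name):
        # Identity in the forward direction; remember the activation and a label.
        ctx.save_for_backward(x)
        ctx.name = name
        return x

    @staticmethod
    def backward(ctx, x_grad):
        x, = ctx.saved_tensors
        print(f"For {ctx.name}, x_grad_sum = {x_grad.sum().item()}, "
              f"x_grad_x_sum = {(x_grad * x).sum().item()}")
        return x_grad, None  # gradient passes through unchanged; None for `name`

# Toy usage: wrap an intermediate tensor, then backprop through a dummy loss.
x = torch.randn(300, 10, 256, requires_grad=True)
y = DebugGrad.apply(x.softmax(dim=-1), "softmax")  # identity, but logs stats on backward
loss = (y ** 2).sum()
loss.backward()  # triggers the print from DebugGrad.backward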