diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 4fb8bef7f..8d0eef40c 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -320,7 +320,8 @@ class BidirectionalConformer(nn.Module):
 
         if num_self_predictor_layers > 0:
             encoder_layer = SimpleCausalEncoderLayer(d_model, dropout=dropout)
-            self.self_predictor_encoder = encoder_layer
+            self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
+                                                          for _ in range(num_self_predictor_layers)])
 
 
         self.discrete_bottleneck = DiscreteBottleneck(
@@ -478,11 +479,11 @@ class BidirectionalConformer(nn.Module):
 
     def self_prediction_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
-            reverse_gradient: bool = True) -> Tensor:
+            reverse_grad: bool = True) -> Tensor:
         """
         Returns the total log-prob of the the labels sampled in the discrete
         bottleneck layer, as predicted using a relatively simple model that
@@ -490,11 +491,11 @@ class BidirectionalConformer(nn.Module):
         [Appears on the denominator of an expression for mutual information].
 
         Args:
-          memory_shifted:
-             It's the output of forward(), with shape [T, N, E], shifted
-             by one so that shifted_memory[t] == memory[t-1], as in:
-                (T, N, E) = memory.shape
-                memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          bn_memory_shifted:
+             It's the bn_memory output of forward(), with shape [T, N, E], shifted
+             by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
+                (T, N, E) = bn_memory.shape
+                bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
              The padding mask from the encoder, of shape [N, T], boolean, True
              for masked locations.
@@ -503,41 +504,41 @@ class BidirectionalConformer(nn.Module):
             as given to the constructor.  This will be needed for the
             'reverse' model.
          softmax: is a "soft" version of `sampled`; if None, will default
             to `sampled`.
-          reverse_gradient: will likely be true.  If true, the gradient is reversed twice
+          reverse_grad: will likely be true.  If true, the gradient is reversed twice
             in this computation, so that we train predictors with the correct gradient,
             i.e. to predict, not anti-predict (since the return value of this function
             will appear with positive, not negative, sign in the loss function, so
             will be minimized).  The gradient w.r.t. the non-self inputs to this
             function, though (i.e.
-            memory_shifted, sampled, softmax) will not be reversed, though.
+            bn_memory_shifted, sampled, softmax) will not be reversed.
 
         Returns:
            A scalar tensor, the **sum** of label smoothing loss over utterances
            in the batch without any normalization.
         """
-        if reverse_gradient:
-            # Reversing gradient for memory_shifted puts the gradient back into
+        if reverse_grad:
+            # Reversing gradient for bn_memory_shifted puts the gradient back into
             # the correct sign; we reversed it in
             # self.discrete_bottleneck.compute_prob(), in order that
             # self.self.predictor_encoder can learn to predict.
             # (You have to read the code in reverse, to reason about
             # what happens to the gradients).
-            memory_shifted = ReverseGrad.apply(memory_shifted)
+            bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
 
         # no mask is needed for self_predictor_encoder; its CNN
         # layer uses left-padding only, making it causal.
-        predictor = self.self_predictor_encoder(memory_shifted)
+        predictor = self.self_predictor_encoder(bn_memory_shifted)
 
         prob = self.discrete_bottleneck.compute_prob(predictor, sampled, softmax,
                                                      memory_key_padding_mask,
-                                                     reverse_gradient=True)
+                                                     reverse_grad=True)
         return prob
 
     def reverse_decoder_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
@@ -552,11 +553,11 @@ class BidirectionalConformer(nn.Module):
         supervision word-sequence.
 
         Args:
-          memory_shifted:
-             It's the output of forward(), with shape [T, N, E], shifted
-             by one so that shifted_memory[t] == memory[t-1], as in:
+          bn_memory_shifted:
+             It's the bn_memory output of forward(), with shape [T, N, E], shifted
+             by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
                 (T, N, E) = memory.shape
-                memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+                bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
              The padding mask from the encoder.
           sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
@@ -582,7 +583,7 @@ class BidirectionalConformer(nn.Module):
         token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id])
                               for utt in token_ids ]
         tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
-                                     padding_value=padding_id).to(memory_shifted.device)
+                                     padding_value=padding_id).to(bn_memory_shifted.device)
         print("tokens_padded = ", tokens_padded)
         tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
 
@@ -600,12 +601,12 @@ class BidirectionalConformer(nn.Module):
             src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.
 
-        T = memory_shifted.shape[0]
+        T = bn_memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
 
         hidden_predictor = self.reverse_decoder(
-            tgt=memory_shifted,
+            tgt=bn_memory_shifted,
             memory=token_memory,
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)
@@ -780,7 +781,7 @@ class DiscreteBottleneck(nn.Module):
 
         torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
 
-    def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Forward computation.
         Args:
@@ -813,7 +814,7 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
-        softmax = x.softmax().reshape(S, N, tot_classes) if need_softmax else None
+        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
 
         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
@@ -829,15 +830,15 @@ class DiscreteBottleneck(nn.Module):
                 self.class_probs = (self.class_probs * self.class_probs_decay +
                                     mean_class_probs * (1.0 - self.class_probs_decay))
                 prob_floor = self.min_prob_ratio / self.classes_per_group
-                self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
+                self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
 
         embedding = self.norm_out(self.linear2(x))
         return (embedding, sampled, softmax)
 
     def compute_prob(self, x: Tensor, sampled: Tensor,
                      softmax: Optional[Tensor],
-                     padding_mask: Optional[Tensor],
-                     reverse_gradient: bool = False) -> Tensor:
+                     padding_mask: Optional[Tensor] = None,
+                     reverse_grad: bool = False) -> Tensor:
         """
         Compute the total probability of the sampled probabilities, given some
         kind of predictor x (which we assume should not have access
@@ -846,7 +847,8 @@ class DiscreteBottleneck(nn.Module):
          x: The predictor tensor, of shape (S, N, E) where S is the sequence
             length, N is the batch size and E is the embedding dim
-            (`dim` arg to __init__())
+            (`dim` arg to __init__()).  This is projected from `sampled`
+            with a learnable matrix.
          sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
             to the constructor, containing the sampled probabilities.
          softmax: A tensor of shape (S, N, C), this is the "smooth" version
@@ -857,7 +859,7 @@ class DiscreteBottleneck(nn.Module):
             (batch_size, sequence_length), with True in masked positions
             that are to be ignored in the sum of probabilities.
-          reverse_gradient: If true, negate the gradient that is passed back
+          reverse_grad: If true, negate the gradient that is passed back
             to 'x' and to the modules self.pred_linear and pred_cross.
             This will be useful in computing a loss function that has
             a likelihood term with negative sign (i.e. the self-prediction).
@@ -867,7 +869,7 @@ class DiscreteBottleneck(nn.Module):
 
        Returns a scalar Tensor represnting the total probability.
""" - if reverse_gradient: + if reverse_grad: sampled = ReverseGrad.apply(sampled) if softmax is None: softmax = sampled @@ -906,7 +908,7 @@ class DiscreteBottleneck(nn.Module): else: tot_prob = (logprobs * softmax).sum() - if reverse_gradient: + if reverse_grad: tot_prob = ReverseGrad.apply(tot_prob) return tot_prob @@ -1789,18 +1791,18 @@ def _test_bidirectional_conformer(): eos_id=2) print("decoder logprob = ", decoder_logprob) - (T, N, E) = memory.shape - memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0) + (T, N, E) = bn_memory.shape + bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0) reverse_decoder_logprob = m.reverse_decoder_forward( - memory_shifted, key_padding_mask, + bn_memory_shifted, key_padding_mask, sampled, softmax, tokens, sos_id=1, eos_id=2, padding_id=0) print("reverse decoder logprob = ", reverse_decoder_logprob) self_prediction_logprob = m.self_prediction_forward( - memory_shifted, key_padding_mask, + bn_memory_shifted, key_padding_mask, sampled, softmax) print("self prediction logprob = ", self_prediction_logprob) @@ -1809,5 +1811,67 @@ def _test_bidirectional_conformer(): loss.backward() +def _test_discrete_bottleneck(): + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + dim = 128 + tot_classes = 256 + num_groups = 4 + interp_prob = 0.8 + straight_through_scale = 0.0 # will change + need_softmax = True + + b = DiscreteBottleneck(dim, tot_classes, num_groups, + interp_prob, straight_through_scale).to(device) + + self_predictor = nn.Linear(tot_classes, dim).to(device) + + + optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())), + lr=1.0e-03, momentum=0.99) + + + for epoch in range(10): + state_dict = dict() + state_dict['b'] = b.state_dict() + state_dict['s'] = self_predictor.state_dict() + torch.save(state_dict, f'epoch-{epoch}.pt') + for i in range(1000): + # TODO: also test padding_mask + T = 300 + N = 10 + + feats = torch.randn(T, N, dim, device=device) + bn_memory, sampled, softmax = b(feats) + + predictor = feats # Could also use `bn_memory`, perhaps. But using + # predictor because it contains everything, will give + # us a bound on max information.. + prob = b.compute_prob(feats, sampled, softmax) + + sampled_reversed = ReverseGrad.apply(sampled) + predictor_reversed = self_predictor(sampled_reversed) + predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device), + predictor_reversed[:-1,:,:]), dim=0) + + self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax, + reverse_grad=True) + + normalized_prob = (prob / (T * N)).to('cpu').item() + normalized_self_prob = (self_prob / (T * N)).to('cpu').item() + + loss = -(prob - self_prob) + + normalized_loss = loss / (T * N) + + if i % 200 == 0: + print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}") + + normalized_loss.backward() + + optim.step() + optim.zero_grad() + + if __name__ == '__main__': _test_bidirectional_conformer() + _test_discrete_bottleneck()