diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index 4aca39ec0..971fc578a 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -579,11 +579,14 @@ class BidirectionalConformer(nn.Module):
 
         # Add both sos and eos symbols to token_ids.  These will be used
         # as an input, there is no harm in adding both of these.
-        token_ids = ([sos_id] + utt + [eos_id] for utt in token_ids)
+        token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
 
-        tokens_padded = pad_sequence(token_ids, batch_first=True, padding_value=padding_id).to(memory.device)
+        tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
+                                     padding_value=padding_id).to(memory_shifted.device)
 
-        tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id)
+        print("tokens_padded = ", tokens_padded)
+        tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
+        print("tokens_key_padding_mask=", tokens_key_padding_mask)
 
         # Let S be the length of the longest sentence (padded)
@@ -597,9 +600,9 @@ class BidirectionalConformer(nn.Module):
             src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.
 
-        T = memory.shape[0]
+        T = memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
 
         hidden_predictor = self.reverse_decoder(
             tgt=memory_shifted,
@@ -607,7 +610,6 @@ class BidirectionalConformer(nn.Module):
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)
 
-
         total_prob = self.discrete_bottleneck.compute_prob(
             hidden_predictor,
             sampled,
@@ -666,7 +668,7 @@ class SimpleCausalEncoderLayer(nn.Module):
 
 
 class ReverseGrad(torch.autograd.Function):
-    def apply(ctx, x):
+    def forward(ctx, x):
         return x
     def backward(ctx, x_grad):
         return -x_grad
@@ -878,13 +880,12 @@ class DiscreteBottleneck(nn.Module):
         pred_cross = self.pred_cross * self.pred_cross_mask
         t = self.tot_classes
         c = self.classes_per_group
-
-        cross_in = sampled[:,:,0:t-c] # all but the last group.  Note: we could possibly
-                                      # use softmax here, but I was concerned about information
-                                      # leakage.
+        # all but the last group.  Note: we could possibly use softmax here,
+        # to reduce variance, but I was concerned about information leakage.
+        sampled_in_part = sampled[:,:,0:t-c]
 
         # row index of pred_cross corresponds to output, col to input -> must transpose
         # before multiply.
-        cross_out = torch.matmul(softmax_in, pred_cross.transpose(0, 1))
+        cross_out = torch.matmul(sampled_in_part, pred_cross.transpose(0, 1))
         # add the output of this matrix multiplication to all but the first
         # group in `logprobs`.  Each group is predicted based on previous
         # groups.
@@ -892,7 +893,7 @@ class DiscreteBottleneck(nn.Module):
         (S, N, C) = logprobs.shape
         logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
         # Normalize the log-probs (so they sum to one)
-        logprobs = torch.nn.functional.logsoftmax(logprobs, dim=-1)
+        logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
         logprobs = logprobs.reshape(S, N, C)
 
         if padding_mask is not None:
@@ -1648,7 +1649,7 @@ class CausalConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         (B, C, T) = x.shape
         padding = self.kernel_size - 1
-        x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.type), x),
+        x = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x),
                       dim=2)
         x = self.depthwise_conv(x)  # <-- This convolution module does no padding,
                                     # so we padded manually, on the left only.
@@ -1780,13 +1781,27 @@ def _test_bidirectional_conformer():
 
     # ctc_output: [N, T, C].
     ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
 
-    decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
+    decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
                                      sos_id=1, eos_id=2)
+    print("decoder logprob = ", decoder_logprob)
 
     (T, N, E) = memory.shape
     memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
 
+    reverse_decoder_logprob = m.reverse_decoder_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax, tokens,
+        sos_id=1, eos_id=2, padding_id=0)
+
+    print("reverse decoder logprob = ", reverse_decoder_logprob)
+
+    self_prediction_logprob = m.self_prediction_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax)
+
+    print("self prediction logprob = ", self_prediction_logprob)
+
 
 if __name__ == '__main__':
     _test_bidirectional_conformer()
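
Note on the first hunk (a standalone sketch, not part of the patch): the token-padding fix works because pad_sequence() requires a list of Tensors, while the old code handed it a generator of plain Python lists. The ids below are made up for illustration:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    sos_id, eos_id, padding_id = 1, 2, 0
    token_ids = [[5, 6, 7], [8, 9]]  # two utterances of different lengths
    # Wrap each utterance in torch.tensor(), as the patch now does;
    # pad_sequence() cannot consume plain lists or a generator.
    token_ids_tensors = [torch.tensor([sos_id] + utt + [eos_id])
                         for utt in token_ids]
    tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
                                 padding_value=padding_id)
    print(tokens_padded)
    # tensor([[1, 5, 6, 7, 2],
    #         [1, 8, 9, 2, 0]])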
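Note on the ReverseGrad hunk: defining apply() on a torch.autograd.Function shadows the inherited method that dispatches into autograd, so the custom backward() was never run; renaming it to forward() restores the intended gradient reversal. A minimal sketch of the same pattern, in the staticmethod style that current PyTorch versions require:

    import torch

    class ReverseGrad(torch.autograd.Function):
        # Identity in the forward pass ...
        @staticmethod
        def forward(ctx, x):
            return x

        # ... negated gradient in the backward pass.
        @staticmethod
        def backward(ctx, x_grad):
            return -x_grad

    x = torch.randn(3, requires_grad=True)
    ReverseGrad.apply(x).sum().backward()
    print(x.grad)  # tensor([-1., -1., -1.]): the gradient is reversed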
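Note on the CausalConvolutionModule hunk: Tensor.type is a method, not a dtype, so dtype=x.type is not a valid argument; dtype=x.dtype is the intended attribute. A standalone sketch of the left-only padding scheme, which keeps the convolution causal while preserving the sequence length:

    import torch
    import torch.nn as nn

    kernel_size, channels = 3, 4
    # padding=0: the conv module itself does no padding; we pad manually,
    # on the left only, so frame t never sees frames t+1, t+2, ...
    depthwise_conv = nn.Conv1d(channels, channels, kernel_size,
                               groups=channels, padding=0)

    x = torch.randn(2, channels, 10)  # (B, C, T)
    (B, C, T) = x.shape
    padding = kernel_size - 1
    x = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x),
                  dim=2)
    print(depthwise_conv(x).shape)  # torch.Size([2, 4, 10]): T is preserved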