Get bidirectional conformer to run

Daniel Povey 2021-09-18 12:32:39 +08:00
parent a75f75bbad
commit 058fff0365


@@ -579,11 +579,14 @@ class BidirectionalConformer(nn.Module):
         # Add both sos and eos symbols to token_ids. These will be used
         # as an input, there is no harm in adding both of these.
-        token_ids = ([sos_id] + utt + [eos_id] for utt in token_ids)
+        token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
-        tokens_padded = pad_sequence(token_ids, batch_first=True, padding_value=padding_id).to(memory.device)
+        tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
+                                     padding_value=padding_id).to(memory_shifted.device)
-        tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id)
+        print("tokens_padded = ", tokens_padded)
+        tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
+        print("tokens_key_padding_mask=", tokens_key_padding_mask)

         # Let S be the length of the longest sentence (padded)
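A note on the first fix above: torch.nn.utils.rnn.pad_sequence expects a list of Tensors, not a generator of Python lists, which is why the old token_ids expression could not work at runtime. A minimal standalone sketch of the corrected pattern (the sos_id/eos_id/padding_id values and the toy utterances are arbitrary, not from the repo):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    sos_id, eos_id, padding_id = 1, 2, 0
    token_ids = [[5, 7, 9], [4, 6]]   # two utterances of different lengths
    tensors = [torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids]
    padded = pad_sequence(tensors, batch_first=True, padding_value=padding_id)
    # padded has shape (2, 5); the shorter row is right-padded with padding_id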
@@ -597,9 +600,9 @@ class BidirectionalConformer(nn.Module):
                                           src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = memory.shape[0]
+        T = memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)

         hidden_predictor = self.reverse_decoder(
             tgt=memory_shifted,
@@ -607,7 +610,6 @@ class BidirectionalConformer(nn.Module):
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)

         total_prob = self.discrete_bottleneck.compute_prob(
             hidden_predictor,
             sampled,
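For readers unfamiliar with the tgt_mask built two hunks up: assuming the repo's generate_square_subsequent_mask helper follows the standard PyTorch convention (as in nn.Transformer), it is a causal mask with -inf strictly above the diagonal, so each decoder position can only attend to itself and the past. A sketch of that convention (hypothetical standalone version, not the repo's code):

    import torch

    def square_subsequent_mask(size: int, device=None) -> torch.Tensor:
        # Positions j > i (the future) get -inf; their attention weights
        # become zero after the softmax inside attention.
        mask = torch.full((size, size), float('-inf'), device=device)
        return torch.triu(mask, diagonal=1)

    print(square_subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])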
@@ -666,7 +668,7 @@ class SimpleCausalEncoderLayer(nn.Module):
 class ReverseGrad(torch.autograd.Function):
-    def apply(ctx, x):
+    def forward(ctx, x):
         return x

     def backward(ctx, x_grad):
         return -x_grad
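This rename matters because torch.autograd.Function dispatches through its inherited apply classmethod; defining a method named apply shadows that dispatcher, so forward/backward would never run. A sketch of the gradient-reversal idiom written in current PyTorch style, with @staticmethod (which the diffed code omits):

    import torch

    class ReverseGrad(torch.autograd.Function):
        """Identity in the forward pass; negates the gradient in backward."""
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, x_grad):
            return -x_grad

    x = torch.randn(3, requires_grad=True)
    ReverseGrad.apply(x).sum().backward()
    print(x.grad)   # all -1.0: the upstream gradient of 1 was negated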
@@ -878,13 +880,12 @@ class DiscreteBottleneck(nn.Module):
         pred_cross = self.pred_cross * self.pred_cross_mask
         t = self.tot_classes
         c = self.classes_per_group
-        cross_in = sampled[:,:,0:t-c] # all but the last group. Note: we could possibly
-                                      # use softmax here, but I was concerned about information
-                                      # leakage.
+        # all but the last group. Note: we could possibly use softmax here,
+        # to reduce variance, but I was concerned about information leakage.
+        sampled_in_part = sampled[:,:,0:t-c]

         # row index of pred_cross corresponds to output, col to input -> must transpose
         # before multiply.
-        cross_out = torch.matmul(softmax_in, pred_cross.transpose(0, 1))
+        cross_out = torch.matmul(sampled_in_part, pred_cross.transpose(0, 1))
         # add the output of this matrix multiplication to all but the first
         # group in `logprobs`. Each group is predicted based on previous
         # groups.
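On the transpose noted in the comment: storing a weight as (out_features, in_features), row = output index and column = input index, is the same convention nn.Linear uses, so the product must be x @ W.T. A toy shape check of that convention (W, x and the dimensions here are hypothetical, chosen only to illustrate):

    import torch

    S, N, in_dim, out_dim = 4, 2, 6, 9
    W = torch.randn(out_dim, in_dim)          # row = output, col = input
    x = torch.randn(S, N, in_dim)
    y = torch.matmul(x, W.transpose(0, 1))    # -> (S, N, out_dim)
    assert y.shape == (S, N, out_dim)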
@@ -892,7 +893,7 @@ class DiscreteBottleneck(nn.Module):
         (S, N, C) = logprobs.shape
         logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
         # Normalize the log-probs (so they sum to one)
-        logprobs = torch.nn.functional.logsoftmax(logprobs, dim=-1)
+        logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
         logprobs = logprobs.reshape(S, N, C)

         if padding_mask is not None:
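The change in this hunk is only the spelling of the functional: torch.nn.functional provides log_softmax, not logsoftmax. The grouped normalization it performs can be checked in isolation; after log_softmax over the last dimension, each group of classes_per_group log-probs exponentiates to a proper distribution (toy sizes below are arbitrary):

    import torch

    S, N, num_groups, classes_per_group = 3, 2, 4, 5
    logits = torch.randn(S, N, num_groups * classes_per_group)
    logprobs = logits.reshape(S, N, num_groups, classes_per_group)
    logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
    print(logprobs.exp().sum(dim=-1))   # all ones, shape (S, N, num_groups)
    logprobs = logprobs.reshape(S, N, num_groups * classes_per_group)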
@@ -1648,7 +1649,7 @@ class CausalConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         (B, C, T) = x.shape
         padding = self.kernel_size - 1
-        x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.type), x),
+        x = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x),
                       dim=2)
         x = self.depthwise_conv(x)  # <-- This convolution module does no padding,
                                     # so we padded manually, on the left only.
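Besides the dtype typo (x.type is a method; torch.zeros needs the x.dtype attribute), this hunk shows the causal-convolution trick: pad kernel_size - 1 zeros on the left only, so each output frame sees only the current and past frames and the sequence length is preserved. A self-contained sketch under toy sizes; F.pad is an equivalent way to build the same left padding as the torch.cat above:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, C, T, kernel_size = 2, 8, 10, 3
    conv = nn.Conv1d(C, C, kernel_size, groups=C)   # depthwise, no padding
    x = torch.randn(B, C, T)
    x_padded = F.pad(x, (kernel_size - 1, 0))       # (left, right) zeros
    y = conv(x_padded)
    assert y.shape == (B, C, T)   # length preserved, strictly causal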
@@ -1780,13 +1781,27 @@ def _test_bidirectional_conformer():
     # ctc_output: [N, T, C].
     ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)

-    decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
+    decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
                                      sos_id=1,
                                      eos_id=2)
+    print("decoder logprob = ", decoder_logprob)
+
+    (T, N, E) = memory.shape
+    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+
+    reverse_decoder_logprob = m.reverse_decoder_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax, tokens,
+        sos_id=1, eos_id=2, padding_id=0)
+    print("reverse decoder logprob = ", reverse_decoder_logprob)
+
+    self_prediction_logprob = m.self_prediction_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax)
+    print("self prediction logprob = ", self_prediction_logprob)

 if __name__ == '__main__':
     _test_bidirectional_conformer()
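The memory_shifted construction added to the test is a one-frame right shift along the time axis with a zero frame prepended, so position t of the shifted memory contains only encoder output up to t-1 (what a causal predictor is allowed to see). A small sketch of just that shift, with arbitrary toy sizes:

    import torch

    T, N, E = 5, 2, 4
    memory = torch.randn(T, N, E)   # (time, batch, channels)
    memory_shifted = torch.cat(
        (torch.zeros(1, N, E), memory[:-1, :, :]), dim=0)
    assert torch.equal(memory_shifted[1:], memory[:-1])
    assert torch.equal(memory_shifted[0], torch.zeros(N, E))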