Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 20:12:24 +00:00

Commit 058fff0365: Get bidirectional conformer to run
Parent: a75f75bbad
@@ -579,11 +579,14 @@ class BidirectionalConformer(nn.Module):
         # Add both sos and eos symbols to token_ids. These will be used
         # as an input, there is no harm in adding both of these.
-        token_ids = ([sos_id] + utt + [eos_id] for utt in token_ids)
+        token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]

-        tokens_padded = pad_sequence(token_ids, batch_first=True, padding_value=padding_id).to(memory.device)
+        tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
+                                     padding_value=padding_id).to(memory_shifted.device)

-        tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id)
+        print("tokens_padded = ", tokens_padded)
+        tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
+        print("tokens_key_padding_mask=", tokens_key_padding_mask)

         # Let S be the length of the longest sentence (padded)
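The change above is needed because torch.nn.utils.rnn.pad_sequence expects a list of tensors, not a generator of Python lists. A minimal standalone sketch of the same padding pattern, with made-up ids (sos_id=1, eos_id=2, padding_id=0) and a boolean comparison standing in for decoder_padding_mask:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    sos_id, eos_id, padding_id = 1, 2, 0      # hypothetical ids, for illustration only
    token_ids = [[5, 7, 9], [4, 6]]           # two utterances of unequal length

    # Wrap each utterance in a tensor so pad_sequence can stack them.
    token_ids_tensors = [torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids]

    # Shape (N, S) because batch_first=True; shorter rows are filled with padding_id.
    tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
                                 padding_value=padding_id)
    print(tokens_padded)
    # tensor([[1, 5, 7, 9, 2],
    #         [1, 4, 6, 2, 0]])

    # A padding mask in the style of decoder_padding_mask: True where padded.
    tokens_key_padding_mask = tokens_padded == padding_id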
@@ -597,9 +600,9 @@ class BidirectionalConformer(nn.Module):
             src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = memory.shape[0]
+        T = memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)

         hidden_predictor = self.reverse_decoder(
             tgt=memory_shifted,
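generate_square_subsequent_mask is not part of this hunk; assuming it follows the usual PyTorch recipe, it returns a (T, T) float mask with 0 at allowed positions and -inf above the diagonal, so position t cannot attend to later positions. A sketch under that assumption:

    import torch

    def generate_square_subsequent_mask(sz: int,
                                        device: torch.device = torch.device("cpu")) -> torch.Tensor:
        """(sz, sz) float mask: 0.0 where j <= i (visible), -inf where j > i (future)."""
        mask = torch.full((sz, sz), float("-inf"), device=device)
        return torch.triu(mask, diagonal=1)

    print(generate_square_subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])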
@@ -607,7 +610,6 @@ class BidirectionalConformer(nn.Module):
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)

-
         total_prob = self.discrete_bottleneck.compute_prob(
             hidden_predictor,
             sampled,
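self.reverse_decoder is also not defined in this diff; if it exposes the torch.nn.TransformerDecoder interface (an assumption), the call above follows the standard pattern of passing a causal tgt_mask together with a memory_key_padding_mask:

    import torch
    import torch.nn as nn

    S, T, N, E = 7, 5, 2, 16   # token length, frame length, batch, embedding dim (made up)
    layer = nn.TransformerDecoderLayer(d_model=E, nhead=4)
    reverse_decoder = nn.TransformerDecoder(layer, num_layers=2)

    memory_shifted = torch.randn(T, N, E)        # acoustic frames, shifted right by one
    token_memory = torch.randn(S, N, E)          # encoded token sequence
    tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
    tokens_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)   # no padding here

    hidden_predictor = reverse_decoder(
        tgt=memory_shifted,
        memory=token_memory,
        tgt_mask=tgt_mask,
        memory_key_padding_mask=tokens_key_padding_mask)
    print(hidden_predictor.shape)   # torch.Size([5, 2, 16])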
@@ -666,7 +668,7 @@ class SimpleCausalEncoderLayer(nn.Module):


 class ReverseGrad(torch.autograd.Function):
-    def apply(ctx, x):
+    def forward(ctx, x):
         return x
     def backward(ctx, x_grad):
         return -x_grad
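Renaming apply to forward matters because Function.apply is the inherited entry point that dispatches to forward/backward; overriding it bypasses autograd entirely. A self-contained sketch of the same gradient-reversal idea, written in the modern @staticmethod style:

    import torch

    class ReverseGrad(torch.autograd.Function):
        """Identity in the forward pass, negates the gradient in the backward pass."""
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, x_grad):
            return -x_grad

    x = torch.ones(3, requires_grad=True)
    y = ReverseGrad.apply(x)    # apply() is inherited and must not be overridden
    y.sum().backward()
    print(x.grad)               # tensor([-1., -1., -1.]) -- gradient sign is flipped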
@@ -878,13 +880,12 @@ class DiscreteBottleneck(nn.Module):
         pred_cross = self.pred_cross * self.pred_cross_mask
         t = self.tot_classes
         c = self.classes_per_group

-        cross_in = sampled[:,:,0:t-c] # all but the last group. Note: we could possibly
-                                      # use softmax here, but I was concerned about information
-                                      # leakage.
+        # all but the last group. Note: we could possibly use softmax here,
+        # to reduce variance, but I was concerned about information leakage.
+        sampled_in_part = sampled[:,:,0:t-c]
         # row index of pred_cross corresponds to output, col to input -> must transpose
         # before multiply.
-        cross_out = torch.matmul(softmax_in, pred_cross.transpose(0, 1))
+        cross_out = torch.matmul(sampled_in_part, pred_cross.transpose(0, 1))
         # add the output of this matrix multiplication to all but the first
         # group in `logprobs`. Each group is predicted based on previous
         # groups.
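The shapes of sampled and pred_cross are not visible in this hunk. Assuming sampled is (S, N, tot_classes) and pred_cross is a square (output, input) matrix over the first tot_classes - classes_per_group dimensions, the shape bookkeeping of the fixed line looks like this (names and sizes are illustrative only):

    import torch

    S, N = 4, 2                               # sequence length, batch size (made up)
    num_groups, classes_per_group = 3, 5
    t = num_groups * classes_per_group        # tot_classes
    c = classes_per_group

    sampled = torch.randn(S, N, t)            # stand-in for the sampled codes
    pred_cross = torch.randn(t - c, t - c)    # (output_dim, input_dim), as the comment says

    sampled_in_part = sampled[:, :, 0:t - c]  # all but the last group: (S, N, t - c)
    # Transpose so the matmul maps the input dim to the output dim.
    cross_out = torch.matmul(sampled_in_part, pred_cross.transpose(0, 1))
    print(cross_out.shape)                    # torch.Size([4, 2, 10])

    # cross_out is then added to all but the first group of logprobs, so that
    # each group is predicted from the groups before it.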
@@ -892,7 +893,7 @@ class DiscreteBottleneck(nn.Module):
         (S, N, C) = logprobs.shape
         logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
         # Normalize the log-probs (so they sum to one)
-        logprobs = torch.nn.functional.logsoftmax(logprobs, dim=-1)
+        logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
         logprobs = logprobs.reshape(S, N, C)

         if padding_mask is not None:
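The fix here is just the function name: torch.nn.functional provides log_softmax, not logsoftmax. Reshaping to (S, N, num_groups, classes_per_group) before normalizing makes the softmax run within each group independently, which a small check confirms:

    import torch
    import torch.nn.functional as F

    S, N, num_groups, classes_per_group = 2, 3, 4, 5
    C = num_groups * classes_per_group

    logprobs = torch.randn(S, N, C)
    logprobs = logprobs.reshape(S, N, num_groups, classes_per_group)
    logprobs = F.log_softmax(logprobs, dim=-1)    # normalize within each group
    logprobs = logprobs.reshape(S, N, C)

    # Each group's probabilities now sum to 1:
    probs = logprobs.reshape(S, N, num_groups, classes_per_group).exp()
    print(probs.sum(dim=-1))    # all entries ~1.0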
@@ -1648,7 +1649,7 @@ class CausalConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         (B, C, T) = x.shape
         padding = self.kernel_size - 1
-        x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.type), x),
+        x = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x),
                       dim=2)
         x = self.depthwise_conv(x)  # <-- This convolution module does no padding,
                                     # so we padded manually, on the left only.
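The bug was dtype=x.type, which passes the bound Tensor.type method rather than a dtype; dtype=x.dtype is what was intended. The left-only padding itself can be checked in isolation; a sketch assuming depthwise_conv is a Conv1d with groups equal to channels and no built-in padding:

    import torch
    import torch.nn as nn

    B, C, T, kernel_size = 1, 4, 10, 3
    depthwise_conv = nn.Conv1d(C, C, kernel_size, groups=C, padding=0)  # depthwise: groups == channels

    x = torch.randn(B, C, T)
    padding = kernel_size - 1
    # Pad on the left only, so output frame t never sees input frames later than t (causal).
    x_padded = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x), dim=2)
    y = depthwise_conv(x_padded)
    print(y.shape)    # torch.Size([1, 4, 10]) -- same length as the input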
@@ -1780,13 +1781,27 @@ def _test_bidirectional_conformer():
     # ctc_output: [N, T, C].
     ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)

-    decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
+    decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
                                      sos_id=1,
                                      eos_id=2)
+    print("decoder logprob = ", decoder_logprob)
+
+    (T, N, E) = memory.shape
+    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+
+    reverse_decoder_logprob = m.reverse_decoder_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax, tokens,
+        sos_id=1, eos_id=2, padding_id=0)
+
+    print("reverse decoder logprob = ", reverse_decoder_logprob)
+
+    self_prediction_logprob = m.self_prediction_forward(
+        memory_shifted, key_padding_mask,
+        sampled, softmax)
+
+    print("self prediction logprob = ", self_prediction_logprob)


 if __name__ == '__main__':
     _test_bidirectional_conformer()
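The memory_shifted construction in the test prepends one zero frame and drops the last frame, so position t of the shifted tensor carries the encoder output for frame t - 1. A standalone check with made-up sizes:

    import torch

    T, N, E = 5, 2, 3
    memory = torch.arange(T * N * E, dtype=torch.float32).reshape(T, N, E)

    # Shift right by one frame along the time axis: prepend zeros, drop the last frame.
    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1, :, :]), dim=0)

    assert memory_shifted.shape == memory.shape
    assert torch.all(memory_shifted[0] == 0)             # first frame is all zeros
    assert torch.equal(memory_shifted[1:], memory[:-1])  # frame t holds old frame t - 1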