From b1b21eb1e4d2d0079aa0b8ae104ccecf13e574f6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 4 Aug 2021 14:57:06 +0800 Subject: [PATCH] Fix decoder padding mask. --- .../ASR/conformer_ctc/transformer.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index a974be4e0..2722e5ba6 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -105,10 +105,7 @@ class Transformer(nn.Module): norm=encoder_norm, ) - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) + self.encoder_output_layer = nn.Linear(d_model, num_classes) if num_decoder_layers > 0: if mmi_loss: @@ -274,9 +271,12 @@ class Transformer(nn.Module): device ) - # TODO: Use eos_id as ignore_id. - # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_pos(tgt) @@ -339,9 +339,9 @@ class Transformer(nn.Module): device ) - # TODO: Use eos_id as ignore_id. - # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + tgt_key_padding_mask[:, 0] = False + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) tgt = self.decoder_pos(tgt)