Fix decoder padding mask.

Fangjun Kuang 2021-08-04 14:57:06 +08:00
parent a6d9b3c9ab
commit b1b21eb1e4


@@ -105,10 +105,7 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )
-        # TODO(fangjun): remove dropout
-        self.encoder_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
-        )
+        self.encoder_output_layer = nn.Linear(d_model, num_classes)
         if num_decoder_layers > 0:
             if mmi_loss:
@@ -274,9 +271,12 @@ class Transformer(nn.Module):
             device
         )
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
         tgt = self.decoder_pos(tgt)
@@ -339,9 +339,9 @@ class Transformer(nn.Module):
             device
         )
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        tgt_key_padding_mask[:, 0] = False
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
         tgt = self.decoder_pos(tgt)