mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix decoder padding mask.
This commit is contained in:
parent
a6d9b3c9ab
commit
b1b21eb1e4
@ -105,10 +105,7 @@ class Transformer(nn.Module):
|
||||
norm=encoder_norm,
|
||||
)
|
||||
|
||||
# TODO(fangjun): remove dropout
|
||||
self.encoder_output_layer = nn.Sequential(
|
||||
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
|
||||
)
|
||||
self.encoder_output_layer = nn.Linear(d_model, num_classes)
|
||||
|
||||
if num_decoder_layers > 0:
|
||||
if mmi_loss:
|
||||
@ -274,9 +271,12 @@ class Transformer(nn.Module):
|
||||
device
|
||||
)
|
||||
|
||||
# TODO: Use eos_id as ignore_id.
|
||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
# We set the first column to False since the first column in ys_in_pad
|
||||
# contains sos_id, which is the same as eos_id in our current setting.
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
@ -339,9 +339,9 @@ class Transformer(nn.Module):
|
||||
device
|
||||
)
|
||||
|
||||
# TODO: Use eos_id as ignore_id.
|
||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
|
Loading…
x
Reference in New Issue
Block a user