fix attention mask

yaozengwei 2023-01-10 23:17:38 +08:00
parent a413ccfcec
commit c87f55671a


@@ -366,14 +366,14 @@ class TransformerDecoder(nn.Module):
         """
         tgt = ys_in_pad
         # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
         # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
+        tgt_mask = tgt_mask | (~m)
         memory = hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)
+        memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)
         tgt = self.embed(tgt)
         tgt = self.pos(tgt)
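
Note: the first hunk flips the boolean convention of the decoder masks. After this change, tgt_mask and memory_mask are True at positions the attention must ignore (padding and, via ~m, future tokens), so the two masks are combined with | instead of building True-means-valid masks and combining them with &. The sketch below illustrates the new convention with simplified stand-ins for make_pad_mask and subsequent_mask; the real icefall helpers may differ in detail, but the diff implies the same semantics (make_pad_mask is True at padded positions, subsequent_mask is True at allowed positions).

import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in: (B, L) bool tensor, True at padded positions.
    max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)


def subsequent_mask(size: int, device=torch.device("cpu")) -> torch.Tensor:
    # Simplified stand-in: (L, L) bool tensor, True where attending is allowed
    # (i.e. to the current and earlier positions).
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))


ys_in_lens = torch.tensor([4, 2])
# (B, 1, L): True where the target token is padding.
tgt_mask = make_pad_mask(ys_in_lens)[:, None, :]
# (1, L, L): True where attending would be allowed by causality.
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# (B, L, L): True where attention must be blocked (padding OR future position).
tgt_mask = tgt_mask | (~m)
print(tgt_mask[1])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False,  True,  True],
#         [False, False,  True,  True]])

Under the old convention the same information was encoded as valid = (~pad) & causal; by De Morgan, pad | (~causal) is just its complement, so the commit only changes which polarity the downstream attention layers expect.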
@@ -721,15 +721,17 @@ def _test_attention_decoder_model():
         attention_dim=192,
         nhead=8,
         feedforward_dim=2048,
-        dropout=0,
+        dropout=0.1,
         sos_id=1,
         eos_id=1,
         ignore_id=-1,
     )
-    encoder_out = torch.randn(2, 100, 384)
-    encoder_out_lens = torch.full((2,), 100)
-    token_ids = [[1, 2], [2, 3, 10]]
+    m.eval()
+    encoder_out = torch.randn(2, 50, 384)
+    encoder_out_lens = torch.full((2,), 50)
+    token_ids = [[1, 2, 3, 4], [2, 3, 10]]
     loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
+    print(loss)


 if __name__ == "__main__":
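
Note: in the test hunk, dropout is raised from 0 to 0.1, so the model is put into eval() mode before the forward pass; otherwise the printed loss would vary from run to run, because dropout is stochastic in training mode. A small illustration of that behaviour with a plain nn.Dropout (not the project's decoder model):

import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.1)
x = torch.ones(4)

drop.train()
print(drop(x))  # some elements zeroed, the rest scaled by 1 / (1 - p); varies per call

drop.eval()
print(drop(x))  # identity in eval mode: tensor([1., 1., 1., 1.])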