mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
fix attention mask
This commit is contained in:
parent
a413ccfcec
commit
c87f55671a
@ -366,14 +366,14 @@ class TransformerDecoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
tgt = ys_in_pad
|
tgt = ys_in_pad
|
||||||
# tgt_mask: (B, 1, L)
|
# tgt_mask: (B, 1, L)
|
||||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
|
||||||
# m: (1, L, L)
|
# m: (1, L, L)
|
||||||
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||||
# tgt_mask: (B, L, L)
|
# tgt_mask: (B, L, L)
|
||||||
tgt_mask = tgt_mask & m
|
tgt_mask = tgt_mask | (~m)
|
||||||
|
|
||||||
memory = hs_pad
|
memory = hs_pad
|
||||||
memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)
|
memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)
|
||||||
|
|
||||||
tgt = self.embed(tgt)
|
tgt = self.embed(tgt)
|
||||||
tgt = self.pos(tgt)
|
tgt = self.pos(tgt)
|
||||||
@ -721,15 +721,17 @@ def _test_attention_decoder_model():
|
|||||||
attention_dim=192,
|
attention_dim=192,
|
||||||
nhead=8,
|
nhead=8,
|
||||||
feedforward_dim=2048,
|
feedforward_dim=2048,
|
||||||
dropout=0,
|
dropout=0.1,
|
||||||
sos_id=1,
|
sos_id=1,
|
||||||
eos_id=1,
|
eos_id=1,
|
||||||
ignore_id=-1,
|
ignore_id=-1,
|
||||||
)
|
)
|
||||||
encoder_out = torch.randn(2, 100, 384)
|
m.eval()
|
||||||
encoder_out_lens = torch.full((2,), 100)
|
encoder_out = torch.randn(2, 50, 384)
|
||||||
token_ids = [[1, 2], [2, 3, 10]]
|
encoder_out_lens = torch.full((2,), 50)
|
||||||
|
token_ids = [[1, 2, 3, 4], [2, 3, 10]]
|
||||||
loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
|
loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user