Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 10:44:19 +00:00)
fix attention mask
commit c87f55671a
parent a413ccfcec
@@ -366,14 +366,14 @@ class TransformerDecoder(nn.Module):
         """
         tgt = ys_in_pad
         # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
         # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
+        tgt_mask = tgt_mask | (~m)

         memory = hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)
+        memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)

         tgt = self.embed(tgt)
         tgt = self.pos(tgt)
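For context, a minimal sketch (not the icefall implementation) of how the two masks in the hunk above combine under the new convention. It uses simplified stand-ins for make_pad_mask and subsequent_mask, assuming make_pad_mask returns True at padded positions and subsequent_mask returns a lower-triangular "allowed" matrix, so the final tgt_mask is True wherever attention must be blocked (padding or future positions):

# Minimal sketch, not the icefall code: simplified stand-ins for
# make_pad_mask and subsequent_mask, illustrating the convention the fix
# switches to, i.e. True in the final mask means "do not attend here".
import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # (B, L) bool, True at padded positions
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)


def subsequent_mask(size: int, device=torch.device("cpu")) -> torch.Tensor:
    # (L, L) bool, True where position i may attend to position j (j <= i)
    return torch.ones(size, size, dtype=torch.bool, device=device).tril()


ys_in_lens = torch.tensor([4, 2])
tgt_mask = make_pad_mask(ys_in_lens)[:, None, :]     # (B, 1, L), True = padding
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)  # (1, L, L), True = allowed
tgt_mask = tgt_mask | (~m)                           # (B, L, L), True = masked out
print(tgt_mask.int())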
@@ -721,15 +721,17 @@ def _test_attention_decoder_model():
         attention_dim=192,
         nhead=8,
         feedforward_dim=2048,
-        dropout=0,
+        dropout=0.1,
         sos_id=1,
         eos_id=1,
         ignore_id=-1,
     )
-    encoder_out = torch.randn(2, 100, 384)
-    encoder_out_lens = torch.full((2,), 100)
-    token_ids = [[1, 2], [2, 3, 10]]
+    m.eval()
+    encoder_out = torch.randn(2, 50, 384)
+    encoder_out_lens = torch.full((2,), 50)
+    token_ids = [[1, 2, 3, 4], [2, 3, 10]]
     loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
     print(loss)


 if __name__ == "__main__":
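The test now uses a non-zero dropout (0.1) and therefore puts the model in eval mode before computing the loss. A minimal sketch of why m.eval() matters once dropout is enabled, using a hypothetical stand-in module rather than the icefall attention decoder model:

# Minimal sketch with a hypothetical stand-in module (not icefall's decoder):
# with dropout > 0, eval() disables the stochastic dropout so repeated forward
# passes give the same result, which is what a deterministic smoke test needs.
import torch
import torch.nn as nn

m = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
x = torch.randn(2, 8)

m.train()
print(torch.allclose(m(x), m(x)))  # usually False: dropout is stochastic in train mode

m.eval()
print(torch.allclose(m(x), m(x)))  # True: dropout is a no-op in eval mode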