diff --git a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
index dbaa2ac98..2cf3fb83e 100644
--- a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
@@ -366,14 +366,14 @@ class TransformerDecoder(nn.Module):
         """
         tgt = ys_in_pad
         # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
         # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
+        tgt_mask = tgt_mask | (~m)
 
         memory = hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)
+        memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)
 
         tgt = self.embed(tgt)
         tgt = self.pos(tgt)
@@ -721,15 +721,17 @@ def _test_attention_decoder_model():
         attention_dim=192,
         nhead=8,
         feedforward_dim=2048,
-        dropout=0,
+        dropout=0.1,
         sos_id=1,
         eos_id=1,
         ignore_id=-1,
     )
-    encoder_out = torch.randn(2, 100, 384)
-    encoder_out_lens = torch.full((2,), 100)
-    token_ids = [[1, 2], [2, 3, 10]]
+    m.eval()
+    encoder_out = torch.randn(2, 50, 384)
+    encoder_out_lens = torch.full((2,), 50)
+    token_ids = [[1, 2, 3, 4], [2, 3, 10]]
     loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
+    print(loss)
 
 
 if __name__ == "__main__":
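
Note on the mask edits in the first hunk: the patch flips both `tgt_mask` and `memory_mask` from a keep convention (True = position may be attended to) to an ignore convention (True = position is masked out), presumably to match the boolean `attn_mask`/`key_padding_mask` convention of `torch.nn.MultiheadAttention`. Below is a minimal sketch of that equivalence, not code from the PR; `make_pad_mask` and `subsequent_mask` here are hypothetical stand-ins for the icefall utilities, assumed to return True at padding positions and True where attention is allowed, respectively.

```python
import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: True at padding positions, shape (B, max_len).
    max_len = int(lengths.max())
    seq = torch.arange(max_len, device=lengths.device)
    return seq.unsqueeze(0) >= lengths.unsqueeze(1)


def subsequent_mask(size: int, device: str = "cpu") -> torch.Tensor:
    # Hypothetical stand-in: True where position i may attend to position j (j <= i).
    return torch.ones(size, size, dtype=torch.bool, device=device).tril()


lengths = torch.tensor([4, 2])
m = subsequent_mask(int(lengths.max())).unsqueeze(0)  # (1, L, L)

# Before the patch: True means "may be attended to".
keep = (~make_pad_mask(lengths))[:, None, :]  # (B, 1, L)
old_tgt_mask = keep & m                       # True = attend

# After the patch: True means "must be ignored".
pad = make_pad_mask(lengths)[:, None, :]      # (B, 1, L)
new_tgt_mask = pad | (~m)                     # True = masked out

# The two conventions are exact complements of each other.
assert torch.equal(new_tgt_mask, ~old_tgt_mask)
print(new_tgt_mask)
```

Under this reading, the patch does not change which positions the decoder may attend to; only the polarity of the masks changes, so whatever consumes `tgt_mask` and `memory_mask` downstream must interpret True as "masked" after this change.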