Fix transformer decoder layer (#1995)

This commit is contained in:
Fangjun Kuang 2025-07-18 20:12:29 +08:00 committed by GitHub
parent 5fe13078cc
commit 34fc1fdf0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 8 additions and 0 deletions

View File

@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module):
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
warmup: float = 1.0,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module):
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.

View File

@@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
warmup: float = 1.0,
**kwargs,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.