mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix transformer decoder layer
This commit is contained in:
parent
9fd0f2dc1d
commit
e2b29afd1d
@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory_mask: Optional[torch.Tensor] = None,
|
memory_mask: Optional[torch.Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user