Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
Fix transformer decoder layer (#1995)
commit 34fc1fdf0d
parent 5fe13078cc
@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

@@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
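Every hunk above applies the same one-line change: a **kwargs catch-all is appended to the forward() signature of the per-recipe TransformerDecoderLayer copies. A plausible motivation (an assumption, not stated in the commit) is that these layers are driven by torch.nn.TransformerDecoder, and recent PyTorch releases pass extra keyword arguments such as tgt_is_causal and memory_is_causal to each layer, which a forward() without the catch-all rejects with a TypeError. The sketch below is a minimal, self-contained stand-in rather than the actual icefall code: only the signature lines come from the hunks, while the class body, dimensions, and the demo call are illustrative assumptions.

# Minimal sketch (not the icefall implementation) of the patched signature.
# The trailing **kwargs is the added line: unexpected keyword arguments are
# accepted and ignored instead of raising a TypeError.
from typing import Optional

import torch
import torch.nn as nn


class TransformerDecoderLayer(nn.Module):
    """Toy stand-in for the per-recipe decoder layers touched by this commit."""

    def __init__(self, d_model: int = 256, nhead: int = 4) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.src_attn = nn.MultiheadAttention(d_model, nhead)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        **kwargs,  # the added line: swallow extra keyword arguments
    ) -> torch.Tensor:
        """Pass the inputs (and mask) through the decoder layer."""
        # Self-attention over the target sequence.
        tgt2, _ = self.self_attn(
            tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + tgt2
        # Cross-attention over the encoder memory.
        tgt2, _ = self.src_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return self.norm(tgt + tgt2)


if __name__ == "__main__":
    layer = TransformerDecoderLayer()
    tgt = torch.randn(10, 2, 256)  # (time, batch, feature)
    memory = torch.randn(20, 2, 256)
    # An extra kwarg like tgt_is_causal is now tolerated by the layer.
    out = layer(tgt, memory, tgt_is_causal=True)
    print(out.shape)

Swallowing unknown keyword arguments keeps a single forward() compatible across PyTorch versions without branching on torch.__version__; the trade-off is that misspelled or unsupported arguments are silently ignored rather than reported.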