mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Fix transformer decoder layer (#1995)
This commit is contained in:
parent
11df2a83fc
commit
2d8e3fd858
@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user