https://github.com/k2-fsa/icefall.git

Remove changes to files outside of relevant recipes

commit 7231cf44aa
parent 36fc1f1d1e
@@ -545,7 +545,6 @@ class TransformerDecoderLayer(nn.Module):
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Pass the inputs (and mask) through the decoder layer.

@@ -612,7 +612,6 @@ class TransformerDecoderLayer(nn.Module):
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        warmup: float = 1.0,
        **kwargs,
    ) -> torch.Tensor:
        """Pass the inputs (and mask) through the decoder layer.
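The two hunks above touch the forward signature of TransformerDecoderLayer; the mask arguments themselves are built by the caller. As a rough, hypothetical sketch for orientation (not part of this commit; the helper names subsequent_mask and lengths_to_padding_mask are illustrative), such masks are typically derived from sequence lengths in plain PyTorch like this:

import torch

def subsequent_mask(size: int) -> torch.Tensor:
    # Float mask with -inf above the diagonal so position i cannot attend
    # to positions > i (suitable for a causal tgt_mask).
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

def lengths_to_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Bool mask where True marks padded positions, the convention used by
    # tgt_key_padding_mask / memory_key_padding_mask.
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

tgt_lens = torch.tensor([3, 5])   # illustrative decoder-side lengths
mem_lens = torch.tensor([7, 4])   # illustrative encoder-side lengths
tgt_mask = subsequent_mask(int(tgt_lens.max()))
tgt_key_padding_mask = lengths_to_padding_mask(tgt_lens, int(tgt_lens.max()))
memory_key_padding_mask = lengths_to_padding_mask(mem_lens, int(mem_lens.max()))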
@@ -1391,20 +1391,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")


-def make_pad_mask(
-    lengths: torch.Tensor,
-    max_len: int = 0,
-    pad_left: bool = False,
-) -> torch.Tensor:
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     """
     Args:
       lengths:
         A 1-D tensor containing sentence lengths.
       max_len:
         The length of masks.
-      pad_left:
-        If ``False`` (default), padding is on the right.
-        If ``True``, padding is on the left.
     Returns:
       Return a 2-D bool tensor, where masked positions
       are filled with `True` and non-masked positions are
@@ -1421,14 +1414,9 @@ def make_pad_mask(
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
     seq_range = torch.arange(0, max_len, device=lengths.device)
-    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

-    if pad_left:
-        mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
-    else:
-        mask = expanded_lengths >= lengths.unsqueeze(-1)
-
-    return mask
+    return expaned_lengths >= lengths.unsqueeze(-1)


 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
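For reference, the make_pad_mask restored by this commit can be exercised on its own. This standalone sketch simply repeats the function body shown in the hunk above and prints the mask for a small example:

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # Body as restored by this commit: True marks positions at or beyond
    # each sequence's length.
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    return expaned_lengths >= lengths.unsqueeze(-1)

print(make_pad_mask(torch.tensor([1, 3, 2])))
# tensor([[False,  True,  True],
#         [False, False, False],
#         [False, False,  True]])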