Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-10 22:45:27 +00:00

Commit d74e2322e0: Merge branch 'master' into multi_ja_en_mls_english_clean
@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
@@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
 
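The hunks above apply the same one-line change to each recipe-local copy of TransformerDecoderLayer: forward() now accepts and ignores extra keyword arguments. A plausible motivation (not stated in the commit message) is caller compatibility: recent torch.nn.TransformerDecoder versions pass additional flags such as tgt_is_causal and memory_is_causal to every layer, and a custom layer without **kwargs would raise a TypeError. A minimal self-contained sketch, not the icefall class itself:

from typing import Optional

import torch
import torch.nn as nn


# Sketch of a decoder layer whose forward() absorbs unrecognized keyword
# arguments via **kwargs, mirroring the change in the hunks above.
class TinyDecoderLayer(nn.Module):
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        **kwargs,  # silently accepts e.g. tgt_is_causal from newer callers
    ) -> torch.Tensor:
        # Stand-in computation; a real layer would run attention here.
        return tgt + memory.mean(dim=0, keepdim=True)


layer = TinyDecoderLayer()
tgt = torch.randn(5, 2, 8)     # (T, N, C)
memory = torch.randn(7, 2, 8)  # (S, N, C)
# Without **kwargs this call would raise
# TypeError: forward() got an unexpected keyword argument 'tgt_is_causal'.
out = layer(tgt, memory, tgt_is_causal=True)
print(out.shape)  # torch.Size([5, 2, 8])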
@@ -124,6 +124,7 @@ def _compute_mmi_loss_exact_non_optimized(
     den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
+    tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0)
 
     loss = -1 * tot_scores.sum()
     return loss
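The added masked_fill guards the batch loss: if any utterance's total score came out infinite (for example, from an empty or degenerate lattice), it is replaced with 0.0 so it contributes nothing to the sum instead of turning the whole loss into inf or NaN. A tiny illustration of the guard:

import torch

# Infinite scores (e.g. from an empty lattice) are zeroed so they drop
# out of the summed loss instead of poisoning the whole batch.
tot_scores = torch.tensor([-12.3, float("-inf"), -7.9])
tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0)
loss = -1 * tot_scores.sum()
print(loss)  # tensor(20.2000)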
@@ -1391,13 +1391,20 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")
 
 
-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+def make_pad_mask(
+    lengths: torch.Tensor,
+    max_len: int = 0,
+    pad_left: bool = False,
+) -> torch.Tensor:
     """
     Args:
       lengths:
         A 1-D tensor containing sentence lengths.
       max_len:
         The length of masks.
+      pad_left:
+        If ``False`` (default), padding is on the right.
+        If ``True``, padding is on the left.
     Returns:
       Return a 2-D bool tensor, where masked positions
       are filled with `True` and non-masked positions are
@@ -1414,9 +1421,14 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
     seq_range = torch.arange(0, max_len, device=lengths.device)
-    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
 
-    return expaned_lengths >= lengths.unsqueeze(-1)
+    if pad_left:
+        mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
+    else:
+        mask = expanded_lengths >= lengths.unsqueeze(-1)
+
+    return mask
 
 
 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
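For reference, a self-contained sketch of the new pad_left semantics, replicating the logic in the hunk above rather than importing icefall: with pad_left=True the True (masked) positions move to the left edge, matching left-padded batches.

import torch


def make_pad_mask(lengths, max_len=0, pad_left=False):
    # Same logic as the updated icefall function above.
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    if pad_left:
        return expanded_lengths < (max_len - lengths).unsqueeze(1)
    return expanded_lengths >= lengths.unsqueeze(-1)


lengths = torch.tensor([1, 3])
print(make_pad_mask(lengths))
# tensor([[False,  True,  True],
#         [False, False, False]])
print(make_pad_mask(lengths, pad_left=True))
# tensor([[ True,  True, False],
#         [False, False, False]])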