mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
remove unnecessary func
This commit is contained in:
parent 833957f645
commit f3618c989a
@@ -241,34 +241,6 @@ def ctc_greedy_search(
     return hyps, scores
 
 
-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
-    """Make mask tensor containing indices of padded part.
-
-    See description of make_non_pad_mask.
-
-    Args:
-        lengths (torch.Tensor): Batch of lengths (B,).
-    Returns:
-        torch.Tensor: Mask tensor containing indices of padded part.
-
-    Examples:
-        >>> lengths = [5, 3, 2]
-        >>> make_pad_mask(lengths)
-        masks = [[0, 0, 0, 0, 0],
-                 [0, 0, 0, 1, 1],
-                 [0, 0, 1, 1, 1]]
-    """
-    batch_size = lengths.size(0)
-    max_len = max_len if max_len > 0 else lengths.max().item()
-    seq_range = torch.arange(
-        0, max_len, dtype=torch.int64, device=lengths.device
-    )
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    seq_length_expand = lengths.unsqueeze(-1)
-    mask = seq_range_expand >= seq_length_expand
-    return mask
-
-
 def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
     new_hyp: List[int] = []
     cur = 0
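For context, the deleted make_pad_mask builds a boolean (B, T) mask that is True at padded positions by comparing a broadcast position index against each sequence length. A minimal standalone sketch of the equivalent computation (the variable names below are illustrative, not from the repo):

import torch

# Reproduce what the removed helper computed, as a one-line broadcast.
lengths = torch.tensor([5, 3, 2])
max_len = int(lengths.max())

# Position t in sequence b is padding iff t >= lengths[b].
pad_mask = (
    torch.arange(max_len, device=lengths.device).unsqueeze(0)
    >= lengths.unsqueeze(1)
)

print(pad_mask.int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)

The printed mask matches the docstring example in the removed function above.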