mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
remove unnecessary func
This commit is contained in:
parent
833957f645
commit
f3618c989a
@ -241,34 +241,6 @@ def ctc_greedy_search(
|
||||
return hyps, scores
|
||||
|
||||
|
||||
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
See description of make_non_pad_mask.
|
||||
|
||||
Args:
|
||||
lengths (torch.Tensor): Batch of lengths (B,).
|
||||
Returns:
|
||||
torch.Tensor: Mask tensor containing indices of padded part.
|
||||
|
||||
Examples:
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
"""
|
||||
batch_size = lengths.size(0)
|
||||
max_len = max_len if max_len > 0 else lengths.max().item()
|
||||
seq_range = torch.arange(
|
||||
0, max_len, dtype=torch.int64, device=lengths.device
|
||||
)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
seq_length_expand = lengths.unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
return mask
|
||||
|
||||
|
||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||
new_hyp: List[int] = []
|
||||
cur = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user