mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
ed44167dab
commit
48b2dddcd9
Binary file not shown.
@ -225,6 +225,29 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def greedy_search(
|
||||
ctc_probs: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Apply CTC greedy search
|
||||
Args:
|
||||
ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
|
||||
mask (torch.Tensor): (batch, max_len)
|
||||
Returns:
|
||||
best path result
|
||||
"""
|
||||
|
||||
_, max_index = ctc_probs.max(2) # (B, maxlen)
|
||||
max_index = max_index.masked_fill_(mask, 0) # (B, maxlen)
|
||||
|
||||
ret_hyps = []
|
||||
for hyp in max_index:
|
||||
hyp = torch.unique_consecutive(hyp)
|
||||
hyp = hyp[hyp > 0].tolist()
|
||||
ret_hyps.append(hyp)
|
||||
return ret_hyps
|
||||
|
||||
|
||||
def ctc_greedy_search(
|
||||
ctc_probs: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user