from local

This commit is contained in:
dohe0342 2023-02-25 15:20:21 +09:00
parent ed44167dab
commit 48b2dddcd9
2 changed files with 23 additions and 0 deletions

View File

@ -225,6 +225,29 @@ def get_params() -> AttributeDict:
return params
def greedy_search(
ctc_probs: torch.Tensor,
mask: torch.Tensor,
) -> List[List[int]]:
"""Apply CTC greedy search
Args:
ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
mask (torch.Tensor): (batch, max_len)
Returns:
best path result
"""
_, max_index = ctc_probs.max(2) # (B, maxlen)
max_index = max_index.masked_fill_(mask, 0) # (B, maxlen)
ret_hyps = []
for hyp in max_index:
hyp = torch.unique_consecutive(hyp)
hyp = hyp[hyp > 0].tolist()
ret_hyps.append(hyp)
return ret_hyps
def ctc_greedy_search(
ctc_probs: torch.Tensor,
mask: torch.Tensor,