from local
This commit is contained in:
parent
ed44167dab
commit
48b2dddcd9
Binary file not shown.
@ -225,6 +225,29 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
ctc_probs: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Apply CTC greedy search
|
||||||
|
Args:
|
||||||
|
ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
|
||||||
|
mask (torch.Tensor): (batch, max_len)
|
||||||
|
Returns:
|
||||||
|
best path result
|
||||||
|
"""
|
||||||
|
|
||||||
|
_, max_index = ctc_probs.max(2) # (B, maxlen)
|
||||||
|
max_index = max_index.masked_fill_(mask, 0) # (B, maxlen)
|
||||||
|
|
||||||
|
ret_hyps = []
|
||||||
|
for hyp in max_index:
|
||||||
|
hyp = torch.unique_consecutive(hyp)
|
||||||
|
hyp = hyp[hyp > 0].tolist()
|
||||||
|
ret_hyps.append(hyp)
|
||||||
|
return ret_hyps
|
||||||
|
|
||||||
|
|
||||||
def ctc_greedy_search(
|
def ctc_greedy_search(
|
||||||
ctc_probs: torch.Tensor,
|
ctc_probs: torch.Tensor,
|
||||||
mask: torch.Tensor,
|
mask: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user