mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Refactor ctc greedy search. (#1691)
Use torch.unique_consecutive() to avoid reinventing the wheel.
This commit is contained in:
parent
d47c078286
commit
2e13298717
@ -1475,21 +1475,10 @@ def rescore_with_rnn_lm(
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
|
||||||
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
|
||||||
new_hyp: List[int] = []
|
|
||||||
cur = 0
|
|
||||||
while cur < len(hyp):
|
|
||||||
if hyp[cur] != 0:
|
|
||||||
new_hyp.append(hyp[cur])
|
|
||||||
prev = cur
|
|
||||||
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
|
||||||
cur += 1
|
|
||||||
return new_hyp
|
|
||||||
|
|
||||||
|
|
||||||
def ctc_greedy_search(
|
def ctc_greedy_search(
|
||||||
ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
|
ctc_output: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
blank_id: int = 0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""CTC greedy search.
|
"""CTC greedy search.
|
||||||
|
|
||||||
@ -1501,6 +1490,10 @@ def ctc_greedy_search(
|
|||||||
"""
|
"""
|
||||||
batch = ctc_output.shape[0]
|
batch = ctc_output.shape[0]
|
||||||
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
||||||
hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)]
|
hyps = [
|
||||||
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
|
||||||
|
]
|
||||||
|
|
||||||
|
hyps = [h[h != blank_id].tolist() for h in hyps]
|
||||||
|
|
||||||
return hyps
|
return hyps
|
||||||
|
42
test/test_ctc_greedy_search.py
Executable file
42
test/test_ctc_greedy_search.py
Executable file
@ -0,0 +1,42 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from icefall.decode import ctc_greedy_search
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
log_probs = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[10, 1, 2, 1, 1, 3, 2, 3],
|
||||||
|
[10, 3, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 10, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 10, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 1, 10, 1, 1, 3, 2, 3],
|
||||||
|
[10, 1, 1, 1, 1, 3, 2, 3],
|
||||||
|
[1, 1, 1, 10, 1, 3, 2, 3],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[10, 1, 2, 1, 1, 3, 2, 3],
|
||||||
|
[10, 3, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 10, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 10, 2, 2, 1, 3, 2, 3],
|
||||||
|
[1, 1, 10, 1, 1, 3, 2, 3],
|
||||||
|
[10, 1, 1, 1, 1, 3, 2, 3],
|
||||||
|
[1, 1, 1, 10, 1, 3, 2, 3],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
).log_softmax(dim=-1)
|
||||||
|
|
||||||
|
log_probs_length = torch.tensor([7, 6])
|
||||||
|
|
||||||
|
hyps = ctc_greedy_search(log_probs, log_probs_length)
|
||||||
|
|
||||||
|
assert hyps[0] == [1, 2, 3], hyps[0]
|
||||||
|
assert hyps[1] == [1, 2], hyps[1]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test()
|
Loading…
x
Reference in New Issue
Block a user