Refactor ctc greedy search. (#1691)

Use torch.unique_consecutive() to avoid reinventing the wheel.
This commit is contained in:
Fangjun Kuang 2024-07-15 12:01:47 +08:00 committed by GitHub
parent d47c078286
commit 2e13298717
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 16 deletions

View File

@ -1475,21 +1475,10 @@ def rescore_with_rnn_lm(
return ans return ans
def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC path into its output token sequence.

    Consecutive repeats of the same symbol are merged into one occurrence,
    and every occurrence of the blank symbol is then dropped. Two equal
    non-blank symbols separated by a blank are therefore both kept.

    Adapted from
    https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py

    Args:
      hyp:
        Per-frame symbol IDs, e.g. the arg-max path of a CTC output.
      blank_id:
        ID of the blank symbol. Defaults to 0.

    Returns:
      A new list with repeats merged and blanks removed. ``hyp`` is not
      modified.
    """
    new_hyp: List[int] = []
    # ``None`` never equals an int, so the first frame always starts a run.
    prev = None
    for token in hyp:
        if token == prev:
            # Still inside the current run of identical symbols.
            continue
        prev = token
        if token != blank_id:
            new_hyp.append(token)
    return new_hyp
def ctc_greedy_search( def ctc_greedy_search(
ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor ctc_output: torch.Tensor,
encoder_out_lens: torch.Tensor,
blank_id: int = 0,
) -> List[List[int]]: ) -> List[List[int]]:
"""CTC greedy search. """CTC greedy search.
@ -1501,6 +1490,10 @@ def ctc_greedy_search(
""" """
batch = ctc_output.shape[0] batch = ctc_output.shape[0]
index = ctc_output.argmax(dim=-1) # (batch, seq_len) index = ctc_output.argmax(dim=-1) # (batch, seq_len)
hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)] hyps = [
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
]
hyps = [h[h != blank_id].tolist() for h in hyps]
return hyps return hyps

42
test/test_ctc_greedy_search.py Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
import torch
from icefall.decode import ctc_greedy_search
def test():
    """Check ctc_greedy_search() on a small hand-built batch.

    Both utterances share the same frames; the second one is truncated to
    its first 6 frames through ``log_probs_length``.
    """
    # Scores over 8 classes for 7 frames. The entry with value 10 is the
    # arg-max of its frame, so the arg-max path is: 0, 0, 1, 1, 2, 0, 3.
    frames = [
        [10, 1, 2, 1, 1, 3, 2, 3],
        [10, 3, 2, 2, 1, 3, 2, 3],
        [1, 10, 2, 2, 1, 3, 2, 3],
        [1, 10, 2, 2, 1, 3, 2, 3],
        [1, 1, 10, 1, 1, 3, 2, 3],
        [10, 1, 1, 1, 1, 3, 2, 3],
        [1, 1, 1, 10, 1, 3, 2, 3],
    ]
    log_probs = torch.tensor(
        [frames, frames],
        dtype=torch.float32,
    ).log_softmax(dim=-1)

    log_probs_length = torch.tensor([7, 6])

    # Default blank (0): repeats are merged, then blanks are removed.
    hyps = ctc_greedy_search(log_probs, log_probs_length)

    assert hyps[0] == [1, 2, 3], hyps[0]
    assert hyps[1] == [1, 2], hyps[1]

    # Non-default blank: symbol 1 is dropped instead, and symbol 0 is kept
    # as a regular token, including its second, non-adjacent occurrence.
    hyps = ctc_greedy_search(log_probs, log_probs_length, blank_id=1)

    assert hyps[0] == [0, 2, 0, 3], hyps[0]
    assert hyps[1] == [0, 2, 0], hyps[1]
# Allow running this file directly as a script, without a test runner.
if __name__ == "__main__":
    test()