Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Add merge_tokens for ctc forced alignment (#1649)
This commit is contained in:
parent ec0389a3c1
commit 13f55d0735
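
For context, a minimal usage sketch of the API this commit adds. It is illustrative only: it assumes icefall is installed so that merge_tokens can be imported from icefall.ctc (as the package __init__ change below suggests), and the toy alignment and printed output are hand-worked, not taken from the commit.

    from icefall.ctc import merge_tokens

    # A toy CTC frame-level alignment; 0 is assumed to be the blank ID.
    alignment = [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 0]

    # Each TokenSpan records a token ID and its [start, end) frame range.
    for span in merge_tokens(alignment, blank=0):
        print(span.token, span.start, span.end)

    # Expected output (hand-derived):
    # 1 1 4
    # 2 4 5
    # 2 8 10
    # 3 10 11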
icefall/ctc/__init__.py

@@ -4,3 +4,4 @@ from .prepare_lang import (
     make_lexicon_fst_with_silence,
 )
 from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
+from .utils import merge_tokens
icefall/ctc/test_utils.py (new executable file, 87 lines)
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)

from typing import List

from utils import TokenSpan, merge_tokens


def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank token.
    Returns:
      Return a list of TokenSpan.
    """
    ans = []
    last_token = None
    last_i = None

    for i, token in enumerate(alignment):
        if token == blank:
            if last_token is None or last_token == token:
                continue

            # end of the last token
            span = TokenSpan(token=last_token, start=last_i, end=i)
            ans.append(span)
            last_token = None
            last_i = None
            continue

        # The current token is not a blank
        if last_token is None or last_token == blank:
            last_token = token
            last_i = i
            continue

        if last_token == token:
            continue

        # end of the last token and start of the current token
        span = TokenSpan(token=last_token, start=last_i, end=i)
        last_token = token
        last_i = i
        ans.append(span)

    if last_token is not None:
        assert last_i is not None, (last_i, last_token)
        span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
        # Note: for the last token, its end is len(alignment), i.e., one past
        # the last frame index.
        ans.append(span)

    return ans


def test_merge_tokens():
    data_list = [
        # 0  1  2  3  4  5  6  7  8  9  10 11 12 13 14
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
    ]

    for data in data_list:
        span1 = merge_tokens(data)
        span2 = inefficient_merge_tokens(data)
        assert span1 == span2, (data, span1, span2)


def main():
    test_merge_tokens()


if __name__ == "__main__":
    main()
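
As an illustrative cross-check (not part of the commit), these are the spans both implementations should return for the first entry of data_list, worked out by hand from the frame indices in the comment row; end is exclusive, so each span covers frames start..end-1.

    # Hand-derived expectation for the first test case (blank = 0):
    # alignment = [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0]
    expected = [
        TokenSpan(token=1, start=1, end=4),
        TokenSpan(token=2, start=4, end=5),
        TokenSpan(token=2, start=8, end=10),
        TokenSpan(token=3, start=10, end=11),
        TokenSpan(token=2, start=11, end=12),
        TokenSpan(token=3, start=12, end=14),
    ]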
icefall/ctc/utils.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class TokenSpan:
    # ID of the token
    token: int

    # Start frame of this token in the output log_prob
    start: int

    # End frame of this token in the output log_prob
    end: int


# See also
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
# We used torchaudio as a reference when implementing this function.
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank token.
    Returns:
      Return a list of TokenSpan.
    """
    alignment_tensor = torch.tensor(alignment, dtype=torch.int32)

    diff = torch.diff(
        alignment_tensor,
        prepend=torch.tensor([-1]),
        append=torch.tensor([-1]),
    )

    non_zero_indexes = torch.nonzero(diff != 0).squeeze().tolist()

    ans = []
    for start, end in zip(non_zero_indexes[:-1], non_zero_indexes[1:]):
        token = alignment[start]
        if token == blank:
            continue
        span = TokenSpan(token=token, start=start, end=end)
        ans.append(span)
    return ans
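
A short note on the vectorized approach above, with a toy illustration that is not part of the commit: prepending and appending -1 forces a value change at both ends of the sequence, so the indices where the difference is non-zero mark the start of every run of identical labels; consecutive pairs of those indices then give each run's [start, end) range, and runs whose label equals the blank are simply skipped.

    import torch

    # Illustration of the boundary-detection trick used by merge_tokens
    # (the alignment values here are a made-up example).
    alignment = torch.tensor([0, 1, 1, 2, 0], dtype=torch.int32)
    diff = torch.diff(
        alignment, prepend=torch.tensor([-1]), append=torch.tensor([-1])
    )
    boundaries = torch.nonzero(diff != 0).squeeze().tolist()
    print(boundaries)  # [0, 1, 3, 4, 5] -> runs [0,1), [1,3), [3,4), [4,5)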