mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Add merge_tokens for ctc forced alignment (#1649)
This commit is contained in:
parent
ec0389a3c1
commit
13f55d0735
@ -4,3 +4,4 @@ from .prepare_lang import (
|
|||||||
make_lexicon_fst_with_silence,
|
make_lexicon_fst_with_silence,
|
||||||
)
|
)
|
||||||
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
|
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
|
||||||
|
from .utils import merge_tokens
|
||||||
|
87
icefall/ctc/test_utils.py
Executable file
87
icefall/ctc/test_utils.py
Executable file
@ -0,0 +1,87 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from utils import TokenSpan, merge_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
|
||||||
|
"""Compute start and end frames of each token from the given alignment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alignment:
|
||||||
|
A list of token IDs.
|
||||||
|
blank_id:
|
||||||
|
ID of the blank.
|
||||||
|
Returns:
|
||||||
|
Return a list of TokenSpan.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
last_token = None
|
||||||
|
last_i = None
|
||||||
|
|
||||||
|
# import pdb
|
||||||
|
|
||||||
|
# pdb.set_trace()
|
||||||
|
for i, token in enumerate(alignment):
|
||||||
|
if token == blank:
|
||||||
|
if last_token is None or last_token == token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# end of the last token
|
||||||
|
span = TokenSpan(token=last_token, start=last_i, end=i)
|
||||||
|
ans.append(span)
|
||||||
|
last_token = None
|
||||||
|
last_i = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The current token is not a blank
|
||||||
|
if last_token is None or last_token == blank:
|
||||||
|
last_token = token
|
||||||
|
last_i = i
|
||||||
|
continue
|
||||||
|
|
||||||
|
if last_token == token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# end of the last token and start of the current token
|
||||||
|
span = TokenSpan(token=last_token, start=last_i, end=i)
|
||||||
|
last_token = token
|
||||||
|
last_i = i
|
||||||
|
ans.append(span)
|
||||||
|
|
||||||
|
if last_token is not None:
|
||||||
|
assert last_i is not None, (last_i, last_token)
|
||||||
|
span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
|
||||||
|
# Note for the last token, its end is larger than len(alignment)-1
|
||||||
|
ans.append(span)
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_tokens():
|
||||||
|
data_list = [
|
||||||
|
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
|
||||||
|
[0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
|
||||||
|
[0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
|
||||||
|
[1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
|
||||||
|
[1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
|
||||||
|
[0, 1, 2, 3, 0],
|
||||||
|
[1, 2, 3, 0],
|
||||||
|
[0, 1, 2, 3],
|
||||||
|
[1, 2, 3],
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in data_list:
|
||||||
|
span1 = merge_tokens(data)
|
||||||
|
span2 = inefficient_merge_tokens(data)
|
||||||
|
assert span1 == span2, (data, span1, span2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_merge_tokens()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
52
icefall/ctc/utils.py
Normal file
52
icefall/ctc/utils.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenSpan:
|
||||||
|
# ID of the token
|
||||||
|
token: int
|
||||||
|
|
||||||
|
# Start frame of this token in the output log_prob
|
||||||
|
start: int
|
||||||
|
|
||||||
|
# End frame of this token in the output log_prob
|
||||||
|
end: int
|
||||||
|
|
||||||
|
|
||||||
|
# See also
|
||||||
|
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
|
||||||
|
# We use torchaudio as a reference while implementing this function
|
||||||
|
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
|
||||||
|
"""Compute start and end frames of each token from the given alignment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alignment:
|
||||||
|
A list of token IDs.
|
||||||
|
blank_id:
|
||||||
|
ID of the blank.
|
||||||
|
Returns:
|
||||||
|
Return a list of TokenSpan.
|
||||||
|
"""
|
||||||
|
alignment_tensor = torch.tensor(alignment, dtype=torch.int32)
|
||||||
|
|
||||||
|
diff = torch.diff(
|
||||||
|
alignment_tensor,
|
||||||
|
prepend=torch.tensor([-1]),
|
||||||
|
append=torch.tensor([-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
non_zero_indexes = torch.nonzero(diff != 0).squeeze().tolist()
|
||||||
|
|
||||||
|
ans = []
|
||||||
|
for start, end in zip(non_zero_indexes[:-1], non_zero_indexes[1:]):
|
||||||
|
token = alignment[start]
|
||||||
|
if token == blank:
|
||||||
|
continue
|
||||||
|
span = TokenSpan(token=token, start=start, end=end)
|
||||||
|
ans.append(span)
|
||||||
|
return ans
|
Loading…
x
Reference in New Issue
Block a user