#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

from typing import List

from utils import TokenSpan, merge_tokens
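
# Note: ``TokenSpan`` and ``merge_tokens`` come from ``utils`` and are not
# shown in this file.  From the way they are used below, ``TokenSpan`` is
# assumed to be a small record type (e.g., a dataclass) with fields ``token``,
# ``start``, and ``end``, where ``start`` is the first frame of the token and
# ``end`` is one past its last frame, and instances compare equal field by
# field.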


def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute the start and end frames of each token from the given alignment.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank token.
    Returns:
      Return a list of TokenSpan.
    """
    ans = []
    last_token = None
    last_i = None

    for i, token in enumerate(alignment):
        if token == blank:
            if last_token is None or last_token == token:
                # No non-blank token is currently open, so there is
                # nothing to close.
                continue

            # End of the last token.
            span = TokenSpan(token=last_token, start=last_i, end=i)
            ans.append(span)
            last_token = None
            last_i = None
            continue

        # The current token is not a blank.
        if last_token is None or last_token == blank:
            # Start of a new token.
            last_token = token
            last_i = i
            continue

        if last_token == token:
            # Still inside the same token.
            continue

        # End of the last token and start of the current token.
        span = TokenSpan(token=last_token, start=last_i, end=i)
        last_token = token
        last_i = i
        ans.append(span)

    if last_token is not None:
        assert last_i is not None, (last_i, last_token)
        # Note: the end of the last token is len(alignment), i.e., one past
        # the last frame index len(alignment) - 1.
        span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
        ans.append(span)

    return ans
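

# The function below is an illustrative sketch added here for comparison, not
# the implementation of ``utils.merge_tokens`` (which is not shown in this
# file), and it is not used by the test: it groups equal consecutive IDs with
# ``itertools.groupby`` and keeps only the non-blank runs.  With blank = 0 it
# should, for example, merge [0, 1, 1, 2, 0, 2] into token 1 over frames
# [1, 3), token 2 over [3, 4), and token 2 over [5, 6), matching
# inefficient_merge_tokens.
def groupby_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Sketch of an equivalent implementation based on itertools.groupby."""
    # Local import so that this sketch stays self-contained.
    from itertools import groupby

    ans = []
    start = 0
    for token, group in groupby(alignment):
        end = start + len(list(group))
        if token != blank:
            ans.append(TokenSpan(token=token, start=start, end=end))
        start = end
    return ans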


def test_merge_tokens():
    data_list = [
        # 0  1  2  3  4  5  6  7  8  9  10 11 12 13 14
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
    ]
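    # For reference (derived by hand from the merging rules above, with
    # blank = 0), the first alignment is expected to merge into:
    #   token 1 over frames [1, 4), token 2 over [4, 5), token 2 over [8, 10),
    #   token 3 over [10, 11), token 2 over [11, 12), token 3 over [12, 14).
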
    for data in data_list:
        span1 = merge_tokens(data)
        span2 = inefficient_merge_tokens(data)
        assert span1 == span2, (data, span1, span2)


def main():
    test_merge_tokens()


if __name__ == "__main__":
    main()