diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py index b546b31af..eb1ec47d1 100644 --- a/icefall/ctc/__init__.py +++ b/icefall/ctc/__init__.py @@ -4,3 +4,4 @@ from .prepare_lang import ( make_lexicon_fst_with_silence, ) from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo +from .utils import merge_tokens diff --git a/icefall/ctc/test_utils.py b/icefall/ctc/test_utils.py new file mode 100755 index 000000000..6fa883dfb --- /dev/null +++ b/icefall/ctc/test_utils.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +from typing import List + +from utils import TokenSpan, merge_tokens + + +def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]: +    """Compute start and end frames of each token from the given alignment. + + Args: + alignment: + A list of token IDs. + blank: + ID of the blank. + Returns: + Return a list of TokenSpan. + """ + ans = [] + last_token = None + last_i = None + + # import pdb + + # pdb.set_trace() + for i, token in enumerate(alignment): + if token == blank: + if last_token is None or last_token == token: + continue + + # end of the last token + span = TokenSpan(token=last_token, start=last_i, end=i) + ans.append(span) + last_token = None + last_i = None + continue + + # The current token is not a blank + if last_token is None or last_token == blank: + last_token = token + last_i = i + continue + + if last_token == token: + continue + + # end of the last token and start of the current token + span = TokenSpan(token=last_token, start=last_i, end=i) + last_token = token + last_i = i + ans.append(span) + + if last_token is not None: + assert last_i is not None, (last_i, last_token) + span = TokenSpan(token=last_token, start=last_i, end=len(alignment)) + # Note for the last token, its end is larger than len(alignment)-1 + ans.append(span) + + return ans + + +def test_merge_tokens(): + data_list = [ + # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 
14 + [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0], + [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3], + [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0], + [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3], + [0, 1, 2, 3, 0], + [1, 2, 3, 0], + [0, 1, 2, 3], + [1, 2, 3], + ] + + for data in data_list: + span1 = merge_tokens(data) + span2 = inefficient_merge_tokens(data) + assert span1 == span2, (data, span1, span2) + + +def main(): + test_merge_tokens() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/utils.py b/icefall/ctc/utils.py new file mode 100644 index 000000000..ad49b5ffd --- /dev/null +++ b/icefall/ctc/utils.py @@ -0,0 +1,52 @@ +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +from dataclasses import dataclass +from typing import List + +import torch + + +@dataclass +class TokenSpan: + # ID of the token + token: int + + # Start frame of this token in the output log_prob + start: int + + # End frame of this token in the output log_prob + end: int + + +# See also +# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96 +# We use torchaudio as a reference while implementing this function +def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]: + """Compute start and end frames of each token from the given alignment. + + Args: + alignment: + A list of token IDs. + blank: + ID of the blank. + Returns: + Return a list of TokenSpan. + """ + alignment_tensor = torch.tensor(alignment, dtype=torch.int32) + + diff = torch.diff( + alignment_tensor, + prepend=torch.tensor([-1]), + append=torch.tensor([-1]), + ) + + non_zero_indexes = torch.nonzero(diff != 0).squeeze().tolist() + + ans = [] + for start, end in zip(non_zero_indexes[:-1], non_zero_indexes[1:]): + token = alignment[start] + if token == blank: + continue + span = TokenSpan(token=token, start=start, end=end) + ans.append(span) + return ans