From 78e1fdc9944b3df2d8566111342060fb94417159 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 10 Sep 2021 21:03:15 +0800
Subject: [PATCH] Convert word IDs in a transcript to token IDs

---
 .../ASR/conformer_mmi_phone/decode.py |   6 -
 .../ASR/conformer_mmi_phone/train.py  |   2 -
 icefall/lexicon.py                    | 149 ++++++++++++------
 icefall/mmi_graph_compiler.py         |  26 ++-
 test/test_bpe_graph_compiler.py       |   9 +-
 test/test_lexicon.py                  | 124 ++++++++-------
 test/test_mmi_graph_compiler.py       |   3 +-
 7 files changed, 193 insertions(+), 126 deletions(-)
 mode change 100644 => 100755 test/test_lexicon.py

diff --git a/egs/librispeech/ASR/conformer_mmi_phone/decode.py b/egs/librispeech/ASR/conformer_mmi_phone/decode.py
index e8b9537a4..7e9c8f78e 100755
--- a/egs/librispeech/ASR/conformer_mmi_phone/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi_phone/decode.py
@@ -27,7 +27,6 @@ from icefall.decode import (
     rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
-from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -417,11 +416,6 @@ def main():
 
     logging.info(f"device: {device}")
 
-    graph_compiler = MmiTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-    )
-
     HLG = k2.Fsa.from_dict(
         torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )
diff --git a/egs/librispeech/ASR/conformer_mmi_phone/train.py b/egs/librispeech/ASR/conformer_mmi_phone/train.py
index 9b0a2d126..402c4a2bb 100755
--- a/egs/librispeech/ASR/conformer_mmi_phone/train.py
+++ b/egs/librispeech/ASR/conformer_mmi_phone/train.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 
 import argparse
-import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -15,7 +14,6 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
-from tdnn_lstm_ctc.model import TdnnLstm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 1e1858801..b7f8472d4 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -84,6 +84,68 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
             f.write(f"{word} {' '.join(tokens)}\n")
 
 
+def convert_lexicon_to_ragged(
+    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
+) -> k2.RaggedTensor:
+    """Read a lexicon and convert it to a ragged tensor.
+
+    Caution:
+      We assume that each word has a unique pronunciation.
+
+    Args:
+      filename:
+        Filename of the lexicon. It has a format that can be read
+        by :func:`read_lexicon`.
+      word_table:
+        The word symbol table.
+      token_table:
+        The token symbol table.
+    Returns:
+      A k2 ragged tensor with two axes [word_id][token_id]
+    """
+    disambig_id = word_table["#0"]
+    # We reuse the same words.txt from the phone based lexicon
+    # so that we can share the same G.fst. Here, we have to
+    # exclude some words present only in the phone based lexicon.
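+    # (The entries below mirror the special symbols at the start of the
+    # librispeech words.txt; adjust this list if your words.txt uses
+    # different special symbols.)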
+ excluded_words = ["", "!SIL", ""] + + # epsilon is not a word, but it occupies a position + # + row_splits = [0] + token_ids_list = [] + + lexicon_tmp = read_lexicon(filename) + lexicon = dict(lexicon_tmp) + if len(lexicon_tmp) != len(lexicon): + raise RuntimeError( + "It's assumed that each word has a unique pronunciation" + ) + + for i in range(disambig_id): + w = word_table[i] + if w in excluded_words: + row_splits.append(row_splits[-1]) + continue + tokens = lexicon[w] + token_ids = [token_table[k] for k in tokens] + + row_splits.append(row_splits[-1] + len(token_ids)) + token_ids_list.extend(token_ids) + + cached_tot_size = row_splits[-1] + row_splits = torch.tensor(row_splits, dtype=torch.int32) + + shape = k2.ragged.create_ragged_shape2( + # row_splits=row_splits, cached_tot_size=cached_tot_size + row_splits, + None, + cached_tot_size, + ) + values = torch.tensor(token_ids_list, dtype=torch.int32) + + return k2.RaggedTensor(shape, values) + + class Lexicon(object): """Phone based lexicon.""" @@ -119,7 +181,7 @@ class Lexicon(object): torch.save(L_inv.as_dict(), lang_dir / "Linv.pt") # We save L_inv instead of L because it will be used to intersect with - # transcript, both of whose labels are word IDs. + # transcript FSAs, both of whose labels are word IDs. self.L_inv = L_inv self.disambig_pattern = disambig_pattern @@ -142,70 +204,61 @@ class Lexicon(object): return ans -class BpeLexicon(Lexicon): +class UniqLexicon(Lexicon): def __init__( self, lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Refer to the help information in Lexicon.__init__. + + uniq_filename: It is assumed to be inside the given `lang_dir`. + Each word in the lexicon is assumed to have a unique pronunciation. """ + lang_dir = Path(lang_dir) super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern) - self.ragged_lexicon = self.convert_lexicon_to_ragged( - lang_dir / "lexicon.txt" + self.ragged_lexicon = convert_lexicon_to_ragged( + filename=lang_dir / uniq_filename, + word_table=self.word_table, + token_table=self.token_table, ) + # TODO: should we move it to a certain device ? - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: - """Read a BPE lexicon from file and convert it to a - k2 ragged tensor. - + def texts_to_token_ids( + self, texts: List[str], oov: str = "" + ) -> k2.RaggedTensor: + """ Args: - filename: - Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt + texts: + A list of transcripts. Each transcript contains space(s) + separated words. An example texts is:: + + ['HELLO k2', 'HELLO icefall'] + oov: + The OOV word. If a word in `texts` is not in the lexicon, it is + replaced with `oov`. Returns: - A k2 ragged tensor with two axes [word_id] + Return a ragged int tensor with 2 axes [utterance][token_id] """ - disambig_id = self.word_table["#0"] - # We reuse the same words.txt from the phone based lexicon - # so that we can share the same G.fst. Here, we have to - # exclude some words present only in the phone based lexicon. 
- excluded_words = ["", "!SIL", ""] + oov_id = self.word_table[oov] - # epsilon is not a word, but it occupies on position - # - row_splits = [0] - token_ids = [] + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(oov_id) + word_ids_list.append(word_ids) + ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32) + return self.ragged_lexicon.index(ragged_indexes, remove_axis=True) - lexicon = read_lexicon(filename) - lexicon = dict(lexicon) - - for i in range(disambig_id): - w = self.word_table[i] - if w in excluded_words: - row_splits.append(row_splits[-1]) - continue - pieces = lexicon[w] - piece_ids = [self.token_table[k] for k in pieces] - - row_splits.append(row_splits[-1] + len(piece_ids)) - token_ids.extend(piece_ids) - - cached_tot_size = row_splits[-1] - row_splits = torch.tensor(row_splits, dtype=torch.int32) - - shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=cached_tot_size - ) - values = torch.tensor(token_ids, dtype=torch.int32) - - return k2.RaggedTensor(shape, values) - - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: - """Convert a list of words to a ragged tensor contained - word piece IDs. - """ + def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor: + """Convert a list of words to a ragged tensor containing token IDs.""" word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 89f725aa7..19f4c1770 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -1,17 +1,18 @@ -from typing import Iterable, List, Tuple, Union import logging +from pathlib import Path +from typing import Iterable, List, Tuple, Union import k2 import torch -from pathlib import Path -from icefall.lexicon import Lexicon +from icefall.lexicon import UniqLexicon class MmiTrainingGraphCompiler(object): def __init__( self, lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", device: Union[str, torch.device] = "cpu", oov: str = "", ): @@ -27,6 +28,9 @@ class MmiTrainingGraphCompiler(object): The above files are generated by the script `prepare.sh`. You should have run it before running the training code. + uniq_filename: + File name to the lexicon in which every word has exactly one + pronunciation. We assume this file is inside the given `lang_dir`. device: It indicates CPU or CUDA. @@ -35,7 +39,7 @@ class MmiTrainingGraphCompiler(object): does not exist in the lexicon, it is replaced with `oov`. """ self.lang_dir = Path(lang_dir) - self.lexicon = Lexicon(lang_dir) + self.lexicon = UniqLexicon(lang_dir) self.device = torch.device(device) self.L_inv = self.lexicon.L_inv.to(self.device) @@ -187,3 +191,17 @@ class MmiTrainingGraphCompiler(object): ).invert_() transcript_fsa = k2.arc_sort(transcript_fsa) return transcript_fsa + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of token IDs. 
+ """ + return self.lexicon.texts_to_token_ids(texts).tolist() diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index e58c4f1c6..6c9073c4c 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -19,20 +19,21 @@ from pathlib import Path from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.lexicon import BpeLexicon +from icefall.lexicon import UniqLexicon + +ICEFALL_DIR = Path(__file__).resolve().parent.parent def test(): - lang_dir = Path("data/lang/bpe") + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe" if not lang_dir.is_dir(): return - # TODO: generate data for testing compiler = BpeCtcTrainingGraphCompiler(lang_dir) ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"]) compiler.compile(ids) - lexicon = BpeLexicon(lang_dir) + lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt") ids0 = lexicon.words_to_piece_ids(["HELLO"]) assert ids[0] == ids0.values().tolist() diff --git a/test/test_lexicon.py b/test/test_lexicon.py old mode 100644 new mode 100755 index 6801b3a89..f315972a2 --- a/test/test_lexicon.py +++ b/test/test_lexicon.py @@ -14,80 +14,84 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +You can run this file in one of the two ways: + + (1) cd icefall; pytest test/test_lexicon.py + (2) cd icefall; ./test/test_lexicon.py +""" +import os +import shutil +import sys from pathlib import Path import k2 -import pytest -import torch -from icefall.lexicon import BpeLexicon, Lexicon +from icefall.lexicon import UniqLexicon + +TMP_DIR = "/tmp/icefall-test-lexicon" +USING_PYTEST = "pytest" in sys.modules +ICEFALL_DIR = Path(__file__).resolve().parent.parent -@pytest.fixture -def lang_dir(tmp_path): - phone2id = """ - 0 - a 1 - b 2 - f 3 - o 4 - r 5 - z 6 - SPN 7 - #0 8 - """ - word2id = """ - 0 - foo 1 - bar 2 - baz 3 - 4 - #0 5 - """ +def generate_test_data(): + # if Path(TMP_DIR).exists(): + # return + Path(TMP_DIR).mkdir(exist_ok=True) + lexicon = """ + SPN +cat c a t +at a t +at a a t +ac a c +ac a c c +""" + lexicon_filename = Path(TMP_DIR) / "lexicon.txt" + with open(lexicon_filename, "w") as f: + for line in lexicon.strip().split("\n"): + f.write(f"{line}\n") - L = k2.Fsa.from_str( - """ - 0 0 7 4 0 - 0 7 -1 -1 0 - 0 1 3 1 0 - 0 3 2 2 0 - 0 5 2 3 0 - 1 2 4 0 0 - 2 0 4 0 0 - 3 4 1 0 0 - 4 0 5 0 0 - 5 6 1 0 0 - 6 0 6 0 0 - 7 - """, - num_aux_labels=1, + os.system( + f""" +cd {ICEFALL_DIR}/egs/librispeech/ASR + +./local/generate_unique_lexicon.py --lang-dir {TMP_DIR} +./local/prepare_lang.py --lang-dir {TMP_DIR} +""" ) - with open(tmp_path / "tokens.txt", "w") as f: - f.write(phone2id) - with open(tmp_path / "words.txt", "w") as f: - f.write(word2id) - torch.save(L.as_dict(), tmp_path / "L.pt") - - return tmp_path +def delete_test_data(): + shutil.rmtree(TMP_DIR) -def test_lexicon(lang_dir): - lexicon = Lexicon(lang_dir) - assert lexicon.tokens == list(range(1, 8)) +def uniq_lexicon_test(): + lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="uniq_lexicon.txt") + + texts = ["cat cat", "at ac", "ca at cat"] + token_ids = lexicon.texts_to_token_ids(texts) + # + # c a t c a t a t a 3 SPN a t c a t + expected_ids = [[3, 2, 4, 3, 2, 4], [2, 4, 2, 3], [1, 2, 4, 3, 2, 4]] + expected_ids = k2.RaggedTensor(expected_ids) + + assert token_ids == expected_ids -def test_bpe_lexicon(): - lang_dir = Path("data/lang/bpe") - if not lang_dir.is_dir(): - return - 
-    # TODO: Generate test data for BpeLexicon
-
-    lexicon = BpeLexicon(lang_dir)
-    words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
-    ids = lexicon.words_to_piece_ids(words)
-    print(ids)
-    print([lexicon.token_table[i] for i in ids.values().tolist()])
+def test_main():
+    generate_test_data()
+
+    uniq_lexicon_test()
+
+    if USING_PYTEST:
+        delete_test_data()
+
+
+def main():
+    test_main()
+
+
+if __name__ == "__main__" and not USING_PYTEST:
+    main()
diff --git a/test/test_mmi_graph_compiler.py b/test/test_mmi_graph_compiler.py
index 336884006..80a1d9722 100755
--- a/test/test_mmi_graph_compiler.py
+++ b/test/test_mmi_graph_compiler.py
@@ -22,10 +22,10 @@ You can run this file in one of the two ways:
 
     (2) cd icefall; ./test/test_mmi_graph_compiler.py
 """
+import copy
 import os
 import shutil
 import sys
-import copy
 from pathlib import Path
 
 import k2
@@ -35,7 +35,6 @@
 from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 
 TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
 USING_PYTEST = "pytest" in sys.modules
 ICEFALL_DIR = Path(__file__).resolve().parent.parent
-print(ICEFALL_DIR)
 
 
 def generate_test_data():