Convert word IDs in a transcript to token IDs

Fangjun Kuang 2021-09-10 21:03:15 +08:00
parent 5390ced2d1
commit 78e1fdc994
7 changed files with 193 additions and 126 deletions


@@ -27,7 +27,6 @@ from icefall.decode import (
     rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
-from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -417,11 +416,6 @@ def main():
     logging.info(f"device: {device}")

-    graph_compiler = MmiTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-    )
-
     HLG = k2.Fsa.from_dict(
         torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )


@@ -1,7 +1,6 @@
 #!/usr/bin/env python3

 import argparse
-import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -15,7 +14,6 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
-from tdnn_lstm_ctc.model import TdnnLstm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter


@@ -84,6 +84,68 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
             f.write(f"{word} {' '.join(tokens)}\n")


+def convert_lexicon_to_ragged(
+    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
+) -> k2.RaggedTensor:
+    """Read a lexicon and convert it to a ragged tensor.
+
+    Caution:
+      We assume that each word has a unique pronunciation.
+
+    Args:
+      filename:
+        Filename of the lexicon. It has a format that can be read
+        by :func:`read_lexicon`.
+      word_table:
+        The word symbol table.
+      token_table:
+        The token symbol table.
+    Returns:
+      A k2 ragged tensor with two axes [word_id][token_id].
+    """
+    disambig_id = word_table["#0"]
+    # We reuse the same words.txt from the phone based lexicon
+    # so that we can share the same G.fst. Here, we have to
+    # exclude some words present only in the phone based lexicon.
+    excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
+
+    # epsilon is not a word, but it occupies a position
+    #
+    row_splits = [0]
+    token_ids_list = []
+
+    lexicon_tmp = read_lexicon(filename)
+    lexicon = dict(lexicon_tmp)
+    if len(lexicon_tmp) != len(lexicon):
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
+
+    for i in range(disambig_id):
+        w = word_table[i]
+        if w in excluded_words:
+            row_splits.append(row_splits[-1])
+            continue
+        tokens = lexicon[w]
+        token_ids = [token_table[k] for k in tokens]
+
+        row_splits.append(row_splits[-1] + len(token_ids))
+        token_ids_list.extend(token_ids)
+
+    cached_tot_size = row_splits[-1]
+    row_splits = torch.tensor(row_splits, dtype=torch.int32)
+
+    shape = k2.ragged.create_ragged_shape2(
+        # row_splits=row_splits, cached_tot_size=cached_tot_size
+        row_splits,
+        None,
+        cached_tot_size,
+    )
+    values = torch.tensor(token_ids_list, dtype=torch.int32)
+    return k2.RaggedTensor(shape, values)
+
+
 class Lexicon(object):
     """Phone based lexicon."""

@@ -119,7 +181,7 @@ class Lexicon(object):
             torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")

         # We save L_inv instead of L because it will be used to intersect with
-        # transcript, both of whose labels are word IDs.
+        # transcript FSAs, both of whose labels are word IDs.
         self.L_inv = L_inv
         self.disambig_pattern = disambig_pattern
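
Note on the added helper: it encodes the whole lexicon as one k2 ragged tensor whose row i holds the token IDs of word i, so a later lookup by word ID yields token IDs directly. Below is a minimal standalone sketch of that encoding, not part of the commit; the toy data (word 0 -> tokens [5, 7], word 1 excluded, word 2 -> token [9]) is invented, and the k2 calls mirror the ones in the diff:

import k2
import torch

# row_splits[i]:row_splits[i+1] delimits the tokens of word i;
# an excluded word gets an empty row (a repeated boundary).
row_splits = torch.tensor([0, 2, 2, 3], dtype=torch.int32)
values = torch.tensor([5, 7, 9], dtype=torch.int32)

# The third argument is cached_tot_size, i.e. len(values);
# passing it saves k2 a recomputation.
shape = k2.ragged.create_ragged_shape2(row_splits, None, 3)
ragged = k2.RaggedTensor(shape, values)
print(ragged)  # [ [ 5 7 ] [ ] [ 9 ] ]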
@@ -142,70 +204,61 @@ class Lexicon(object):
         return ans


-class BpeLexicon(Lexicon):
+class UniqLexicon(Lexicon):
     def __init__(
         self,
         lang_dir: Path,
+        uniq_filename: str = "uniq_lexicon.txt",
         disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.
+
+        uniq_filename: It is assumed to be inside the given `lang_dir`.
+
+        Each word in the lexicon is assumed to have a unique pronunciation.
         """
+        lang_dir = Path(lang_dir)
         super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)

-        self.ragged_lexicon = self.convert_lexicon_to_ragged(
-            lang_dir / "lexicon.txt"
+        self.ragged_lexicon = convert_lexicon_to_ragged(
+            filename=lang_dir / uniq_filename,
+            word_table=self.word_table,
+            token_table=self.token_table,
         )
+        # TODO: should we move it to a certain device?

-    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
-        """Read a BPE lexicon from file and convert it to a
-        k2 ragged tensor.
-
-        Args:
-          filename:
-            Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
-        Returns:
-          A k2 ragged tensor with two axes [word_id]
+    def texts_to_token_ids(
+        self, texts: List[str], oov: str = "<UNK>"
+    ) -> k2.RaggedTensor:
+        """
+        Args:
+          texts:
+            A list of transcripts. Each transcript contains space(s)
+            separated words. An example texts is::
+
+                ['HELLO k2', 'HELLO icefall']
+          oov:
+            The OOV word. If a word in `texts` is not in the lexicon, it is
+            replaced with `oov`.
+        Returns:
+          Return a ragged int tensor with 2 axes [utterance][token_id]
         """
-        disambig_id = self.word_table["#0"]
-        # We reuse the same words.txt from the phone based lexicon
-        # so that we can share the same G.fst. Here, we have to
-        # exclude some words present only in the phone based lexicon.
-        excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
-        # epsilon is not a word, but it occupies on position
-        #
-        row_splits = [0]
-        token_ids = []
-
-        lexicon = read_lexicon(filename)
-        lexicon = dict(lexicon)
-
-        for i in range(disambig_id):
-            w = self.word_table[i]
-            if w in excluded_words:
-                row_splits.append(row_splits[-1])
-                continue
-            pieces = lexicon[w]
-            piece_ids = [self.token_table[k] for k in pieces]
-
-            row_splits.append(row_splits[-1] + len(piece_ids))
-            token_ids.extend(piece_ids)
-
-        cached_tot_size = row_splits[-1]
-        row_splits = torch.tensor(row_splits, dtype=torch.int32)
-        shape = k2.ragged.create_ragged_shape2(
-            row_splits=row_splits, cached_tot_size=cached_tot_size
-        )
-        values = torch.tensor(token_ids, dtype=torch.int32)
-        return k2.RaggedTensor(shape, values)
+        oov_id = self.word_table[oov]
+
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(oov_id)
+            word_ids_list.append(word_ids)
+
+        ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
+        return self.ragged_lexicon.index(ragged_indexes, remove_axis=True)

-    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
-        """Convert a list of words to a ragged tensor contained
-        word piece IDs.
-        """
+    def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
+        """Convert a list of words to a ragged tensor containing token IDs."""
         word_ids = [self.word_table[w] for w in words]
         word_ids = torch.tensor(word_ids, dtype=torch.int32)


@@ -1,17 +1,18 @@
-from typing import Iterable, List, Tuple, Union
 import logging
+from pathlib import Path
+from typing import Iterable, List, Tuple, Union

 import k2
 import torch
-from pathlib import Path

-from icefall.lexicon import Lexicon
+from icefall.lexicon import UniqLexicon


 class MmiTrainingGraphCompiler(object):
     def __init__(
         self,
         lang_dir: Path,
+        uniq_filename: str = "uniq_lexicon.txt",
         device: Union[str, torch.device] = "cpu",
         oov: str = "<UNK>",
     ):
@@ -27,6 +28,9 @@ class MmiTrainingGraphCompiler(object):
             The above files are generated by the script `prepare.sh`. You
             should have run it before running the training code.
+          uniq_filename:
+            File name of the lexicon in which every word has exactly one
+            pronunciation. We assume this file is inside the given `lang_dir`.
           device:
             It indicates CPU or CUDA.
@@ -35,7 +39,7 @@ class MmiTrainingGraphCompiler(object):
             does not exist in the lexicon, it is replaced with `oov`.
         """
         self.lang_dir = Path(lang_dir)
-        self.lexicon = Lexicon(lang_dir)
+        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
         self.device = torch.device(device)

         self.L_inv = self.lexicon.L_inv.to(self.device)
@@ -187,3 +191,17 @@ class MmiTrainingGraphCompiler(object):
         ).invert_()
         transcript_fsa = k2.arc_sort(transcript_fsa)
         return transcript_fsa
+
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of token IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of token IDs.
+        """
+        return self.lexicon.texts_to_token_ids(texts).tolist()
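
A hedged sketch of how the new texts_to_ids method might be called from training code; the lang dir path is illustrative, and only the API shown in this diff is used:

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

compiler = MmiTrainingGraphCompiler("data/lang_phone", device="cpu")
# Returns a plain List[List[int]]: one list of token IDs per utterance.
token_ids = compiler.texts_to_ids(["HELLO ICEFALL", "HELLO k2"])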


@@ -19,20 +19,21 @@
 from pathlib import Path

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
-from icefall.lexicon import BpeLexicon
+from icefall.lexicon import UniqLexicon
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent


 def test():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
     if not lang_dir.is_dir():
         return
+    # TODO: generate data for testing
     compiler = BpeCtcTrainingGraphCompiler(lang_dir)
     ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
     compiler.compile(ids)

-    lexicon = BpeLexicon(lang_dir)
-    ids0 = lexicon.words_to_piece_ids(["HELLO"])
+    lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
+    ids0 = lexicon.words_to_token_ids(["HELLO"])

     assert ids[0] == ids0.values().tolist()
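
(A note, not from the commit: the test presumably passes uniq_filename="lexicon.txt" because a BPE lexicon.txt already has exactly one pronunciation per word, so it satisfies UniqLexicon's uniqueness assumption without a separate uniq_lexicon.txt.)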

test/test_lexicon.py Normal file → Executable file

@@ -14,80 +14,84 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+You can run this file in one of the two ways:
+
+    (1) cd icefall; pytest test/test_lexicon.py
+    (2) cd icefall; ./test/test_lexicon.py
+"""
+
+import os
+import shutil
+import sys
 from pathlib import Path

 import k2
-import pytest
-import torch

-from icefall.lexicon import BpeLexicon, Lexicon
+from icefall.lexicon import UniqLexicon
+
+TMP_DIR = "/tmp/icefall-test-lexicon"
+USING_PYTEST = "pytest" in sys.modules
+ICEFALL_DIR = Path(__file__).resolve().parent.parent


-@pytest.fixture
-def lang_dir(tmp_path):
-    phone2id = """
-        <eps> 0
-        a 1
-        b 2
-        f 3
-        o 4
-        r 5
-        z 6
-        SPN 7
-        #0 8
-    """
-    word2id = """
-        <eps> 0
-        foo 1
-        bar 2
-        baz 3
-        <UNK> 4
-        #0 5
-    """
-
-    L = k2.Fsa.from_str(
-        """
-            0 0 7 4 0
-            0 7 -1 -1 0
-            0 1 3 1 0
-            0 3 2 2 0
-            0 5 2 3 0
-            1 2 4 0 0
-            2 0 4 0 0
-            3 4 1 0 0
-            4 0 5 0 0
-            5 6 1 0 0
-            6 0 6 0 0
-            7
-        """,
-        num_aux_labels=1,
-    )
-
-    with open(tmp_path / "tokens.txt", "w") as f:
-        f.write(phone2id)
-    with open(tmp_path / "words.txt", "w") as f:
-        f.write(word2id)
-
-    torch.save(L.as_dict(), tmp_path / "L.pt")
-
-    return tmp_path
+def generate_test_data():
+    # if Path(TMP_DIR).exists():
+    #     return
+    Path(TMP_DIR).mkdir(exist_ok=True)
+    lexicon = """
+<UNK> SPN
+cat c a t
+at a t
+at a a t
+ac a c
+ac a c c
+"""
+    lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
+    with open(lexicon_filename, "w") as f:
+        for line in lexicon.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    os.system(
+        f"""
+cd {ICEFALL_DIR}/egs/librispeech/ASR
+
+./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
+./local/prepare_lang.py --lang-dir {TMP_DIR}
+"""
+    )
+
+
+def delete_test_data():
+    shutil.rmtree(TMP_DIR)


-def test_lexicon(lang_dir):
-    lexicon = Lexicon(lang_dir)
-    assert lexicon.tokens == list(range(1, 8))
+def uniq_lexicon_test():
+    lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="uniq_lexicon.txt")
+
+    texts = ["cat cat", "at ac", "ca at cat"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+
+    #
+    #                  c  a  t  c  a  t     a  t  a  c     SPN a  t  c  a  t
+    expected_ids = [[3, 2, 4, 3, 2, 4], [2, 4, 2, 3], [1, 2, 4, 3, 2, 4]]
+    expected_ids = k2.RaggedTensor(expected_ids)
+
+    assert token_ids == expected_ids


-def test_bpe_lexicon():
-    lang_dir = Path("data/lang/bpe")
-    if not lang_dir.is_dir():
-        return
-    # TODO: Generate test data for BpeLexicon
-    lexicon = BpeLexicon(lang_dir)
-    words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
-    ids = lexicon.words_to_piece_ids(words)
-    print(ids)
-    print([lexicon.token_table[i] for i in ids.values().tolist()])
+def test_main():
+    generate_test_data()
+
+    uniq_lexicon_test()
+
+    if USING_PYTEST:
+        delete_test_data()
+
+
+def main():
+    test_main()
+
+
+if __name__ == "__main__" and not USING_PYTEST:
+    main()
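
A self-contained sketch, not part of the commit, checking the arithmetic behind expected_ids above. The token-ID assignments SPN=1, a=2, c=3, t=4 are an assumption inferred from the expected values, and it assumes generate_unique_lexicon.py keeps only the first pronunciation of each word, which is what drops "at a a t" and "ac a c c":

# Assumed token IDs (inferred from expected_ids): SPN=1, a=2, c=3, t=4.
token_id = {"SPN": 1, "a": 2, "c": 3, "t": 4}
# First pronunciation per word survives in uniq_lexicon.txt (assumption
# about generate_unique_lexicon.py, consistent with expected_ids).
uniq_lexicon = {
    "<UNK>": ["SPN"],
    "cat": ["c", "a", "t"],
    "at": ["a", "t"],
    "ac": ["a", "c"],
}

def to_ids(text):
    # OOV words such as "ca" fall back to the <UNK> pronunciation.
    return [
        token_id[t]
        for w in text.split()
        for t in uniq_lexicon.get(w, uniq_lexicon["<UNK>"])
    ]

assert to_ids("cat cat") == [3, 2, 4, 3, 2, 4]
assert to_ids("at ac") == [2, 4, 2, 3]
assert to_ids("ca at cat") == [1, 2, 4, 3, 2, 4]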


@@ -22,10 +22,10 @@ You can run this file in one of the two ways:
     (2) cd icefall; ./test/test_mmi_graph_compiler.py
 """

+import copy
 import os
 import shutil
 import sys
-import copy
 from pathlib import Path

 import k2
@@ -35,7 +35,6 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
 USING_PYTEST = "pytest" in sys.modules
 ICEFALL_DIR = Path(__file__).resolve().parent.parent
-print(ICEFALL_DIR)


 def generate_test_data():