Convert word IDs in a transcript to token IDs

Fangjun Kuang 2021-09-10 21:03:15 +08:00
parent 5390ced2d1
commit 78e1fdc994
7 changed files with 193 additions and 126 deletions

View File

@@ -27,7 +27,6 @@ from icefall.decode import (
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
get_texts,
@@ -417,11 +416,6 @@ def main():
logging.info(f"device: {device}")
graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
device=device,
)
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python3
import argparse
import gc
import logging
from pathlib import Path
from shutil import copyfile
@@ -15,7 +14,6 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from tdnn_lstm_ctc.model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

View File

@@ -84,6 +84,68 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
f.write(f"{word} {' '.join(tokens)}\n")
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert it to a ragged tensor.
Caution:
We assume that each word has a unique pronunciation.
Args:
filename:
Filename of the lexicon. It has a format that can be read
by :func:`read_lexicon`.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word_id][token_id]
"""
disambig_id = word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies a position in words.txt, so it
# still gets an (empty) row below to keep word IDs aligned with row indexes.
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(disambig_id):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,  # row_ids is not given; it is computed from row_splits
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
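To make the row_splits bookkeeping concrete, here is a minimal sketch (not part of the commit) that builds the same kind of ragged tensor for a hypothetical three-word lexicon with made-up token IDs:

import k2
import torch

# Hypothetical lexicon: word 0 -> tokens [3, 7], word 1 -> [2, 5],
# word 2 -> [2, 1, 9]
row_splits = torch.tensor([0, 2, 4, 7], dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(row_splits, None, 7)
values = torch.tensor([3, 7, 2, 5, 2, 1, 9], dtype=torch.int32)
ragged = k2.RaggedTensor(shape, values)
# Row i holds the token IDs of the word whose word ID is i; an excluded
# word contributes an empty row (two equal consecutive row_splits entries).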
class Lexicon(object):
"""Phone based lexicon."""
@@ -119,7 +181,7 @@ class Lexicon(object):
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with
# transcript, both of whose labels are word IDs.
# transcript FSAs, both of whose labels are word IDs.
self.L_inv = L_inv
self.disambig_pattern = disambig_pattern
@@ -142,70 +204,61 @@ class Lexicon(object):
return ans
class BpeLexicon(Lexicon):
class UniqLexicon(Lexicon):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
uniq_filename: It is assumed to be inside the given `lang_dir`.
Each word in the lexicon is assumed to have a unique pronunciation.
"""
lang_dir = Path(lang_dir)
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = self.convert_lexicon_to_ragged(
lang_dir / "lexicon.txt"
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=lang_dir / uniq_filename,
word_table=self.word_table,
token_table=self.token_table,
)
# TODO: should we move it to a given device?
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a
k2 ragged tensor.
def texts_to_token_ids(
self, texts: List[str], oov: str = "<UNK>"
) -> k2.RaggedTensor:
"""
Args:
filename:
Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
texts:
A list of transcripts. Each transcript consists of
space-separated words. An example is::
['HELLO k2', 'HELLO icefall']
oov:
The OOV word. If a word in `texts` is not in the lexicon, it is
replaced with `oov`.
Returns:
A k2 ragged tensor with two axes [word_id]
Return a ragged int tensor with 2 axes [utterance][token_id]
"""
disambig_id = self.word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
oov_id = self.word_table[oov]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids = []
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
return self.ragged_lexicon.index(ragged_indexes, remove_axis=True)
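For illustration, a hedged usage sketch of the new method; the lang dir path is hypothetical and assumes uniq_lexicon.txt has been generated:

lexicon = UniqLexicon("data/lang_bpe")  # hypothetical lang dir
token_ids = lexicon.texts_to_token_ids(["HELLO k2", "HELLO icefall"])
# Word IDs index axis 0 of the ragged lexicon; remove_axis=True merges
# the [word] and [token] axes, so token_ids has axes [utterance][token_id].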
lexicon = read_lexicon(filename)
lexicon = dict(lexicon)
for i in range(disambig_id):
w = self.word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
pieces = lexicon[w]
piece_ids = [self.token_table[k] for k in pieces]
row_splits.append(row_splits[-1] + len(piece_ids))
token_ids.extend(piece_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=cached_tot_size
)
values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
"""
def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor containing token IDs."""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
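For comparison, a brief usage sketch of words_to_token_ids under the same assumptions; it takes individual words rather than whole transcripts:

ids = lexicon.words_to_token_ids(["HELLO", "WORLD"])
# ids is a k2.RaggedTensor with axes [word][token_id], one row per word.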

View File

@@ -1,17 +1,18 @@
from typing import Iterable, List, Tuple, Union
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union
import k2
import torch
from pathlib import Path
from icefall.lexicon import Lexicon
from icefall.lexicon import UniqLexicon
class MmiTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
device: Union[str, torch.device] = "cpu",
oov: str = "<UNK>",
):
@@ -27,6 +28,9 @@ class MmiTrainingGraphCompiler(object):
The above files are generated by the script `prepare.sh`. You
should have run it before running the training code.
uniq_filename:
Filename of the lexicon in which every word has exactly one
pronunciation. We assume this file is inside the given `lang_dir`.
device:
It indicates CPU or CUDA.
@@ -35,7 +39,7 @@ class MmiTrainingGraphCompiler(object):
does not exist in the lexicon, it is replaced with `oov`.
"""
self.lang_dir = Path(lang_dir)
self.lexicon = Lexicon(lang_dir)
self.lexicon = UniqLexicon(lang_dir)
self.device = torch.device(device)
self.L_inv = self.lexicon.L_inv.to(self.device)
@@ -187,3 +191,17 @@ class MmiTrainingGraphCompiler(object):
).invert_()
transcript_fsa = k2.arc_sort(transcript_fsa)
return transcript_fsa
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
It is a list of strings. Each string consists of
space-separated words. An example containing two strings is given below:
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of token IDs.
"""
return self.lexicon.texts_to_token_ids(texts).tolist()
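A matching usage sketch, again with a hypothetical lang dir; the method simply converts the ragged tensor produced by the lexicon into a plain Python list-of-list:

compiler = MmiTrainingGraphCompiler("data/lang_bpe")  # hypothetical path
ids = compiler.texts_to_ids(["HELLO ICEFALL", "HELLO k2"])
# ids is a List[List[int]] holding one token-ID list per utterance.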

View File

@@ -19,20 +19,21 @@
from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
from icefall.lexicon import UniqLexicon
ICEFALL_DIR = Path(__file__).resolve().parent.parent
def test():
lang_dir = Path("data/lang/bpe")
lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
if not lang_dir.is_dir():
return
# TODO: generate data for testing
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
compiler.compile(ids)
lexicon = BpeLexicon(lang_dir)
lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
ids0 = lexicon.words_to_token_ids(["HELLO"])
assert ids[0] == ids0.values().tolist()

test/test_lexicon.py Normal file → Executable file
View File

@@ -14,80 +14,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_lexicon.py
(2) cd icefall; ./test/test_lexicon.py
"""
import os
import shutil
import sys
from pathlib import Path
import k2
import pytest
import torch
from icefall.lexicon import BpeLexicon, Lexicon
from icefall.lexicon import UniqLexicon
TMP_DIR = "/tmp/icefall-test-lexicon"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
@pytest.fixture
def lang_dir(tmp_path):
phone2id = """
<eps> 0
a 1
b 2
f 3
o 4
r 5
z 6
SPN 7
#0 8
"""
word2id = """
<eps> 0
foo 1
bar 2
baz 3
<UNK> 4
#0 5
"""
def generate_test_data():
# if Path(TMP_DIR).exists():
# return
Path(TMP_DIR).mkdir(exist_ok=True)
lexicon = """
<UNK> SPN
cat c a t
at a t
at a a t
ac a c
ac a c c
"""
lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
with open(lexicon_filename, "w") as f:
for line in lexicon.strip().split("\n"):
f.write(f"{line}\n")
L = k2.Fsa.from_str(
"""
0 0 7 4 0
0 7 -1 -1 0
0 1 3 1 0
0 3 2 2 0
0 5 2 3 0
1 2 4 0 0
2 0 4 0 0
3 4 1 0 0
4 0 5 0 0
5 6 1 0 0
6 0 6 0 0
7
""",
num_aux_labels=1,
os.system(
f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR
./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
./local/prepare_lang.py --lang-dir {TMP_DIR}
"""
)
with open(tmp_path / "tokens.txt", "w") as f:
f.write(phone2id)
with open(tmp_path / "words.txt", "w") as f:
f.write(word2id)
torch.save(L.as_dict(), tmp_path / "L.pt")
return tmp_path
def delete_test_data():
shutil.rmtree(TMP_DIR)
def test_lexicon(lang_dir):
lexicon = Lexicon(lang_dir)
assert lexicon.tokens == list(range(1, 8))
def uniq_lexicon_test():
lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="uniq_lexicon.txt")
texts = ["cat cat", "at ac", "ca at cat"]
token_ids = lexicon.texts_to_token_ids(texts)
# "cat cat"   -> c a t   c a t
# "at ac"     -> a t     a c
# "ca at cat" -> SPN     a t   c a t   ("ca" is OOV, mapped to <UNK>)
expected_ids = [[3, 2, 4, 3, 2, 4], [2, 4, 2, 3], [1, 2, 4, 3, 2, 4]]
expected_ids = k2.RaggedTensor(expected_ids)
assert token_ids == expected_ids
def test_bpe_lexicon():
lang_dir = Path("data/lang/bpe")
if not lang_dir.is_dir():
return
# TODO: Generate test data for BpeLexicon
def test_main():
generate_test_data()
lexicon = BpeLexicon(lang_dir)
words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
ids = lexicon.words_to_piece_ids(words)
print(ids)
print([lexicon.token_table[i] for i in ids.values().tolist()])
uniq_lexicon_test()
if USING_PYTEST:
delete_test_data()
def main():
test_main()
if __name__ == "__main__" and not USING_PYTEST:
main()

View File

@@ -22,10 +22,10 @@ You can run this file in one of the two ways:
(2) cd icefall; ./test/test_mmi_graph_compiler.py
"""
import copy
import os
import shutil
import sys
import copy
from pathlib import Path
import k2
@@ -35,7 +35,6 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
print(ICEFALL_DIR)
def generate_test_data():