Convert word IDs in a transcript to token IDs
This commit is contained in: parent 5390ced2d1 · commit 78e1fdc994
@@ -27,7 +27,6 @@ from icefall.decode import (
    rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
    AttributeDict,
    get_texts,
@@ -417,11 +416,6 @@ def main():

    logging.info(f"device: {device}")

    graph_compiler = MmiTrainingGraphCompiler(
        params.lang_dir,
        device=device,
    )

    HLG = k2.Fsa.from_dict(
        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
    )

@@ -1,7 +1,6 @@
#!/usr/bin/env python3

import argparse
import gc
import logging
from pathlib import Path
from shutil import copyfile
@@ -15,7 +14,6 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from tdnn_lstm_ctc.model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

@@ -84,6 +84,68 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
            f.write(f"{word} {' '.join(tokens)}\n")


def convert_lexicon_to_ragged(
    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
    """Read a lexicon and convert it to a ragged tensor.

    Caution:
      We assume that each word has a unique pronunciation.

    Args:
      filename:
        Filename of the lexicon. It has a format that can be read
        by :func:`read_lexicon`.
      word_table:
        The word symbol table.
      token_table:
        The token symbol table.
    Returns:
      A k2 ragged tensor with two axes [word_id][token_id]
    """
    disambig_id = word_table["#0"]
    # We reuse the same words.txt from the phone based lexicon
    # so that we can share the same G.fst. Here, we have to
    # exclude some words present only in the phone based lexicon.
    excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]

    # epsilon is not a word, but it occupies a position
    #
    row_splits = [0]
    token_ids_list = []

    lexicon_tmp = read_lexicon(filename)
    lexicon = dict(lexicon_tmp)
    if len(lexicon_tmp) != len(lexicon):
        raise RuntimeError(
            "It's assumed that each word has a unique pronunciation"
        )

    for i in range(disambig_id):
        w = word_table[i]
        if w in excluded_words:
            row_splits.append(row_splits[-1])
            continue
        tokens = lexicon[w]
        token_ids = [token_table[k] for k in tokens]

        row_splits.append(row_splits[-1] + len(token_ids))
        token_ids_list.extend(token_ids)

    cached_tot_size = row_splits[-1]
    row_splits = torch.tensor(row_splits, dtype=torch.int32)

    shape = k2.ragged.create_ragged_shape2(
        # row_splits=row_splits, cached_tot_size=cached_tot_size
        row_splits,
        None,
        cached_tot_size,
    )
    values = torch.tensor(token_ids_list, dtype=torch.int32)

    return k2.RaggedTensor(shape, values)

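A quick usage sketch of the new module-level helper. The paths below are assumptions for illustration (a phone-based lang dir containing words.txt, tokens.txt and a deduplicated lexicon), not something this commit prescribes:

# Sketch only: build the [word_id][token_id] ragged lexicon for a lang dir.
# "data/lang_phone" is an assumed path; "uniq_lexicon.txt" matches the default
# file name used by UniqLexicon below.
import k2

from icefall.lexicon import convert_lexicon_to_ragged

word_table = k2.SymbolTable.from_file("data/lang_phone/words.txt")
token_table = k2.SymbolTable.from_file("data/lang_phone/tokens.txt")

ragged_lexicon = convert_lexicon_to_ragged(
    filename="data/lang_phone/uniq_lexicon.txt",
    word_table=word_table,
    token_table=token_table,
)
# Row i holds the token IDs of the word whose ID is i; excluded words
# (<eps>, !SIL, <SPOKEN_NOISE>) get empty rows.
print(ragged_lexicon)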

class Lexicon(object):
    """Phone based lexicon."""

@@ -119,7 +181,7 @@ class Lexicon(object):
        torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")

        # We save L_inv instead of L because it will be used to intersect with
        # transcript, both of whose labels are word IDs.
        # transcript FSAs, both of whose labels are word IDs.
        self.L_inv = L_inv
        self.disambig_pattern = disambig_pattern

@@ -142,70 +204,61 @@ class Lexicon(object):
        return ans


class BpeLexicon(Lexicon):
class UniqLexicon(Lexicon):
    def __init__(
        self,
        lang_dir: Path,
        uniq_filename: str = "uniq_lexicon.txt",
        disambig_pattern: str = re.compile(r"^#\d+$"),
    ):
        """
        Refer to the help information in Lexicon.__init__.

        uniq_filename: It is assumed to be inside the given `lang_dir`.
            Each word in the lexicon is assumed to have a unique pronunciation.
        """
        lang_dir = Path(lang_dir)
        super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)

        self.ragged_lexicon = self.convert_lexicon_to_ragged(
            lang_dir / "lexicon.txt"
        self.ragged_lexicon = convert_lexicon_to_ragged(
            filename=lang_dir / uniq_filename,
            word_table=self.word_table,
            token_table=self.token_table,
        )
        # TODO: should we move it to a certain device ?

    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
        """Read a BPE lexicon from file and convert it to a
        k2 ragged tensor.

    def texts_to_token_ids(
        self, texts: List[str], oov: str = "<UNK>"
    ) -> k2.RaggedTensor:
        """
        Args:
          filename:
            Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
          texts:
            A list of transcripts. Each transcript contains space(s)
            separated words. An example texts is::

                ['HELLO k2', 'HELLO icefall']
          oov:
            The OOV word. If a word in `texts` is not in the lexicon, it is
            replaced with `oov`.
        Returns:
          A k2 ragged tensor with two axes [word_id]
          Return a ragged int tensor with 2 axes [utterance][token_id]
        """
        disambig_id = self.word_table["#0"]
        # We reuse the same words.txt from the phone based lexicon
        # so that we can share the same G.fst. Here, we have to
        # exclude some words present only in the phone based lexicon.
        excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
        oov_id = self.word_table[oov]

        # epsilon is not a word, but it occupies on position
        #
        row_splits = [0]
        token_ids = []
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.word_table:
                    word_ids.append(self.word_table[word])
                else:
                    word_ids.append(oov_id)
            word_ids_list.append(word_ids)
        ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
        return self.ragged_lexicon.index(ragged_indexes, remove_axis=True)

        lexicon = read_lexicon(filename)
        lexicon = dict(lexicon)

        for i in range(disambig_id):
            w = self.word_table[i]
            if w in excluded_words:
                row_splits.append(row_splits[-1])
                continue
            pieces = lexicon[w]
            piece_ids = [self.token_table[k] for k in pieces]

            row_splits.append(row_splits[-1] + len(piece_ids))
            token_ids.extend(piece_ids)

        cached_tot_size = row_splits[-1]
        row_splits = torch.tensor(row_splits, dtype=torch.int32)

        shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=cached_tot_size
        )
        values = torch.tensor(token_ids, dtype=torch.int32)

        return k2.RaggedTensor(shape, values)

    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
        """Convert a list of words to a ragged tensor contained
        word piece IDs.
        """
    def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
        """Convert a list of words to a ragged tensor containing token IDs."""
        word_ids = [self.word_table[w] for w in words]
        word_ids = torch.tensor(word_ids, dtype=torch.int32)

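The class above is the core of the commit: texts_to_token_ids() first maps every word of a transcript to a word ID (OOV words fall back to `oov`), then a single ragged index into self.ragged_lexicon expands the word IDs into token IDs. A minimal sketch of the intended call, assuming a prepared lang dir (the path is a placeholder):

# Sketch only: convert word-level transcripts to token IDs with UniqLexicon.
from icefall.lexicon import UniqLexicon

lexicon = UniqLexicon("data/lang_phone")  # reads uniq_lexicon.txt from the dir

texts = ["HELLO k2", "HELLO icefall"]
token_ids = lexicon.texts_to_token_ids(texts)
# token_ids is a k2.RaggedTensor with axes [utterance][token_id]: each row is
# the concatenation of the pronunciations of that utterance's words.
print(token_ids)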
@@ -1,17 +1,18 @@
from typing import Iterable, List, Tuple, Union
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union

import k2
import torch
from pathlib import Path

from icefall.lexicon import Lexicon
from icefall.lexicon import UniqLexicon


class MmiTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        uniq_filename: str = "uniq_lexicon.txt",
        device: Union[str, torch.device] = "cpu",
        oov: str = "<UNK>",
    ):
@@ -27,6 +28,9 @@ class MmiTrainingGraphCompiler(object):

            The above files are generated by the script `prepare.sh`. You
            should have run it before running the training code.
          uniq_filename:
            File name to the lexicon in which every word has exactly one
            pronunciation. We assume this file is inside the given `lang_dir`.

          device:
            It indicates CPU or CUDA.
@@ -35,7 +39,7 @@ class MmiTrainingGraphCompiler(object):
            does not exist in the lexicon, it is replaced with `oov`.
        """
        self.lang_dir = Path(lang_dir)
        self.lexicon = Lexicon(lang_dir)
        self.lexicon = UniqLexicon(lang_dir)
        self.device = torch.device(device)

        self.L_inv = self.lexicon.L_inv.to(self.device)
@@ -187,3 +191,17 @@ class MmiTrainingGraphCompiler(object):
        ).invert_()
        transcript_fsa = k2.arc_sort(transcript_fsa)
        return transcript_fsa

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of piece IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of space(s)
            separated words. An example containing two strings is given below:

                ['HELLO ICEFALL', 'HELLO k2']
        Returns:
          Return a list-of-list of token IDs.
        """
        return self.lexicon.texts_to_token_ids(texts).tolist()

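With texts_to_ids() on the compiler, a training or decoding script no longer needs to construct a Lexicon and do the word-to-token mapping itself; it hands over raw transcripts and gets plain Python lists back. A rough sketch with placeholder paths and transcripts:

# Sketch only: transcripts -> token IDs via the graph compiler.
import torch

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

graph_compiler = MmiTrainingGraphCompiler(
    "data/lang_phone",          # assumed lang dir containing uniq_lexicon.txt
    device=torch.device("cpu"),
)

texts = ["HELLO ICEFALL", "HELLO k2"]
token_ids = graph_compiler.texts_to_ids(texts)  # List[List[int]], one list per utterance
# Internally this is just self.lexicon.texts_to_token_ids(texts).tolist().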
@@ -19,20 +19,21 @@
from pathlib import Path

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
from icefall.lexicon import UniqLexicon

ICEFALL_DIR = Path(__file__).resolve().parent.parent


def test():
    lang_dir = Path("data/lang/bpe")
    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
    if not lang_dir.is_dir():
        return
        # TODO: generate data for testing

    compiler = BpeCtcTrainingGraphCompiler(lang_dir)
    ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
    compiler.compile(ids)

    lexicon = BpeLexicon(lang_dir)
    lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
    ids0 = lexicon.words_to_piece_ids(["HELLO"])
    assert ids[0] == ids0.values().tolist()

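Note on the test above: a BPE lexicon already has exactly one tokenization per word, which is why plain lexicon.txt can be passed as uniq_filename. A phone-based lexicon (see the `at a t` / `at a a t` entries in the test data further down) has to be deduplicated first; the tests invoke ./local/generate_unique_lexicon.py for that. The sketch below only illustrates the invariant such a step has to establish (one pronunciation kept per word), not the script's actual implementation:

# Sketch only: derive a uniq lexicon by keeping the first pronunciation seen
# for each word. The real logic lives in ./local/generate_unique_lexicon.py;
# the paths are placeholders.
from icefall.lexicon import read_lexicon, write_lexicon

lexicon = read_lexicon("data/lang_phone/lexicon.txt")  # [(word, [token, ...]), ...]

seen = set()
uniq = []
for word, tokens in lexicon:
    if word not in seen:
        seen.add(word)
        uniq.append((word, tokens))

write_lexicon("data/lang_phone/uniq_lexicon.txt", uniq)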
test/test_lexicon.py · 124 · Normal file → Executable file
@@ -14,80 +14,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:

    (1) cd icefall; pytest test/test_lexicon.py
    (2) cd icefall; ./test/test_lexicon.py
"""


import os
import shutil
import sys
from pathlib import Path

import k2
import pytest
import torch

from icefall.lexicon import BpeLexicon, Lexicon
from icefall.lexicon import UniqLexicon

TMP_DIR = "/tmp/icefall-test-lexicon"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent


@pytest.fixture
def lang_dir(tmp_path):
    phone2id = """
        <eps> 0
        a 1
        b 2
        f 3
        o 4
        r 5
        z 6
        SPN 7
        #0 8
    """
    word2id = """
        <eps> 0
        foo 1
        bar 2
        baz 3
        <UNK> 4
        #0 5
    """
def generate_test_data():
    #  if Path(TMP_DIR).exists():
    #      return
    Path(TMP_DIR).mkdir(exist_ok=True)
    lexicon = """
        <UNK> SPN
        cat c a t
        at a t
        at a a t
        ac a c
        ac a c c
    """
    lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
    with open(lexicon_filename, "w") as f:
        for line in lexicon.strip().split("\n"):
            f.write(f"{line}\n")

    L = k2.Fsa.from_str(
        """
        0 0 7 4 0
        0 7 -1 -1 0
        0 1 3 1 0
        0 3 2 2 0
        0 5 2 3 0
        1 2 4 0 0
        2 0 4 0 0
        3 4 1 0 0
        4 0 5 0 0
        5 6 1 0 0
        6 0 6 0 0
        7
        """,
        num_aux_labels=1,
    os.system(
        f"""
        cd {ICEFALL_DIR}/egs/librispeech/ASR

        ./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
        ./local/prepare_lang.py --lang-dir {TMP_DIR}
        """
    )

    with open(tmp_path / "tokens.txt", "w") as f:
        f.write(phone2id)
    with open(tmp_path / "words.txt", "w") as f:
        f.write(word2id)

    torch.save(L.as_dict(), tmp_path / "L.pt")

    return tmp_path
def delete_test_data():
    shutil.rmtree(TMP_DIR)


def test_lexicon(lang_dir):
    lexicon = Lexicon(lang_dir)
    assert lexicon.tokens == list(range(1, 8))
def uniq_lexicon_test():
    lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="uniq_lexicon.txt")

    texts = ["cat cat", "at ac", "ca at cat"]
    token_ids = lexicon.texts_to_token_ids(texts)
    #
    # c a t c a t a t a 3 SPN a t c a t
    expected_ids = [[3, 2, 4, 3, 2, 4], [2, 4, 2, 3], [1, 2, 4, 3, 2, 4]]
    expected_ids = k2.RaggedTensor(expected_ids)

    assert token_ids == expected_ids


def test_bpe_lexicon():
    lang_dir = Path("data/lang/bpe")
    if not lang_dir.is_dir():
        return
        # TODO: Generate test data for BpeLexicon
def test_main():
    generate_test_data()

    lexicon = BpeLexicon(lang_dir)
    words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
    ids = lexicon.words_to_piece_ids(words)
    print(ids)
    print([lexicon.token_table[i] for i in ids.values().tolist()])
    uniq_lexicon_test()

    if USING_PYTEST:
        delete_test_data()


def main():
    test_main()


if __name__ == "__main__" and not USING_PYTEST:
    main()

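For readers checking uniq_lexicon_test above: the expected IDs are consistent with the token table that prepare_lang.py would presumably assign to this tiny lexicon (<eps>=0, SPN=1, a=2, c=3, t=4) and with a uniq lexicon that keeps one pronunciation per word. A hand-rolled sketch under those assumptions:

# Sketch only: reproduce the expected IDs of uniq_lexicon_test by hand.
# The token-ID assignment below is an assumption about what prepare_lang.py
# produces for this test lexicon, not something stated in the diff.
token_id = {"SPN": 1, "a": 2, "c": 3, "t": 4}
uniq_lexicon = {"<UNK>": ["SPN"], "cat": ["c", "a", "t"], "at": ["a", "t"], "ac": ["a", "c"]}

def to_ids(text: str) -> list:
    tokens = []
    for word in text.split():
        tokens += uniq_lexicon.get(word, uniq_lexicon["<UNK>"])  # OOV -> <UNK> -> SPN
    return [token_id[t] for t in tokens]

assert to_ids("cat cat") == [3, 2, 4, 3, 2, 4]
assert to_ids("at ac") == [2, 4, 2, 3]
assert to_ids("ca at cat") == [1, 2, 4, 3, 2, 4]  # "ca" is OOV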
@@ -22,10 +22,10 @@ You can run this file in one of the two ways:
    (2) cd icefall; ./test/test_mmi_graph_compiler.py
"""

import copy
import os
import shutil
import sys
import copy
from pathlib import Path

import k2
@@ -35,7 +35,6 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
print(ICEFALL_DIR)


def generate_test_data():