Add prepare_lang.py based on prepare_lang.sh

This commit is contained in:
Fangjun Kuang 2021-07-20 19:41:21 +08:00
parent e005ea062c
commit d5e0408698
3 changed files with 467 additions and 3 deletions

View File

@ -29,10 +29,12 @@ def download_lm():
filename = target_dir / f filename = target_dir / f
if filename.is_file() is False: if filename.is_file() is False:
urlretrieve_progress( urlretrieve_progress(
f"{url}/{f}", filename=filename, desc=f"Downloading {filename}", f"{url}/{f}",
filename=filename,
desc=f"Downloading {filename}",
) )
else: else:
print(f'{filename} already exists - skipping') print(f"{filename} already exists - skipping")
if ".gz" in str(filename): if ".gz" in str(filename):
unzip_file = Path(os.path.splitext(filename)[0]) unzip_file = Path(os.path.splitext(filename)[0])
@ -41,7 +43,7 @@ def download_lm():
with open(unzip_file, "wb") as f_out: with open(unzip_file, "wb") as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
else: else:
print(f'{unzip_file} already exist - skipping') print(f"{unzip_file} already exist - skipping")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1,376 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang/lexicon.txt"
consisting of words and phones and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate phones.txt, the phones table mapping a phone to a unique integer.
3. Generate words.txt, the words table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
6. Generate lexicon_disambig.txt
"""
import math
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import torch
Lexicon = List[Tuple[str, List[str]]]
def read_lexicon(filename: str) -> Lexicon:
"""Read a lexicon.txt in `filename`.
Each line in the lexicon contains "word p1 p2 p3 ...".
That is, the first field is a word and the remaining
fields are phones. Fields are separated by space(s).
We assume that the input lexicon does not contain words:
<eps>, <s>, </s>, !SIL, <SPOKEN_NOISE>, <UNK>.
Args:
filename:
Path to the lexicon.txt
Returns:
A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])]
"""
# ans = ["!SIL", ["SIL"]]
# ans.append(["<SPOKEN_NOISE>", ["SPN"]])
# ans.append(["<UNK>", ["SPN"]])
ans = []
with open(filename, "r", encoding="latin-1") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) == 0:
continue
if len(a) < 2:
print(f"Found bad line {line} in lexicon file {filename}")
print("Every line is expected to contain at least 2 fields")
sys.exit(1)
word = a[0]
if word == "<eps>":
print(f"Found bad line {line} in lexicon file {filename}")
print("<eps> should not be a valid word")
sys.exit(1)
prons = a[1:]
ans.append((word, prons))
return ans
def write_lexicon(filename: str, lexicon: Lexicon) -> None:
"""Write a lexicon to a file.
Args:
filename:
Path to the lexicon file to be generated.
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w") as f:
for word, prons in lexicon:
f.write(f"{word} {' '.join(prons)}\n")
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_phones(lexicon: Lexicon) -> List[str]:
"""Get phones from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique phones.
"""
ans = set()
for _, prons in lexicon:
ans.update(prons)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: List[Tuple[str, List[str]]]) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-phone disambiguation symbols #1, #2 and so on
at the ends of phones to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbols
"""
# (1) Work out the count of each phone-sequence in the
# lexicon.
count = defaultdict(int)
for _, prons in lexicon:
count[" ".join(prons)] += 1
# (2) For each left sub-sequence of each phone-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, prons in lexicon:
prons = prons.copy()
prons.pop()
while prons:
issubseq[" ".join(prons)] = 1
prons.pop()
# (3) For each entry in the lexicon:
# if the phone sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same phone-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, prons in lexicon:
phnseq = " ".join(prons)
assert phnseq != ""
if issubseq[phnseq] == 0 and count[phnseq] == 1:
ans.append((word, prons))
continue
cur_disambig = last_used_disambig_symbol_of[phnseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[phnseq] = cur_disambig
phnseq += f" #{cur_disambig}"
ans.append((word, phnseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def lexicon_to_fst(
lexicon: Lexicon,
phone2id: Dict[str, int],
word2id: Dict[str, int],
sil_phone: str = "SIL",
sil_prob: float = 0.5,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of the word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
phone2id:
A dict mapping phones to IDs.
word2id:
A dict mapping words to IDs.
sil_phone:
The silence phone.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert phone2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_phone = phone2id[sil_phone]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_phone, eps, 0])
for word, prons in lexicon:
assert len(prons) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
prons = [phone2id[i] for i in prons]
for i in range(len(prons) - 1):
if i == 0:
arcs.append([cur_state, next_state, prons[i], word, 0])
else:
arcs.append([cur_state, next_state, prons[i], eps, 0])
cur_state = next_state
next_state += 1
# now for the last phone of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(prons) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, prons[i], w, no_sil_score])
arcs.append([cur_state, sil_state, prons[i], w, sil_score])
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def main():
out_dir = Path("data/lang")
lexicon_filename = out_dir / "lexicon.txt"
sil_phone = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in phones
phones.append(f"#{i}")
assert "<eps>" not in phones
phones = ["<eps>"] + phones
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
write_mapping(out_dir / "phones.txt", phone2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
phone2id=phone2id,
word2id=word2id,
sil_phone=sil_phone,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
phone2id=phone2id,
word2id=word2id,
sil_phone=sil_phone,
sil_prob=sil_prob,
)
# TODO(fangjun): add self-loops to L_disambig
# whose ilabel is phone2id['#0'] and olable is word2id['#0']
# Need to implement it in k2
if False:
# Just for debugging, will remove it
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.svg", title="L")
L_disambig.draw(out_dir / "L_disambig.svg", title="L_disambig")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import os
import tempfile
import k2
from prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,
get_words,
lexicon_to_fst,
read_lexicon,
write_lexicon,
write_mapping,
)
def generate_lexicon_file() -> str:
fd, filename = tempfile.mkstemp()
os.close(fd)
s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
with open(filename, "w") as f:
f.write(s)
return filename
def test_read_lexicon(filename: str):
lexicon = read_lexicon(filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
print(lexicon)
print(phones)
print(words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
print(lexicon_disambig)
print("max disambig:", f"#{max_disambig}")
phones = ["<eps>", "SIL", "SPN"] + phones
for i in range(max_disambig + 1):
phones.append(f"#{i}")
words = ["<eps>"] + words
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
print(phone2id)
print(word2id)
write_mapping("phones.txt", phone2id)
write_mapping("words.txt", word2id)
write_lexicon("a.txt", lexicon)
write_lexicon("a_disambig.txt", lexicon_disambig)
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
if __name__ == "__main__":
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)