mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add prepare_lang.py based on prepare_lang.sh
This commit is contained in:
parent
e005ea062c
commit
d5e0408698
@ -29,10 +29,12 @@ def download_lm():
|
||||
filename = target_dir / f
|
||||
if filename.is_file() is False:
|
||||
urlretrieve_progress(
|
||||
f"{url}/{f}", filename=filename, desc=f"Downloading {filename}",
|
||||
f"{url}/{f}",
|
||||
filename=filename,
|
||||
desc=f"Downloading {filename}",
|
||||
)
|
||||
else:
|
||||
print(f'{filename} already exists - skipping')
|
||||
print(f"{filename} already exists - skipping")
|
||||
|
||||
if ".gz" in str(filename):
|
||||
unzip_file = Path(os.path.splitext(filename)[0])
|
||||
@ -41,7 +43,7 @@ def download_lm():
|
||||
with open(unzip_file, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
else:
|
||||
print(f'{unzip_file} already exist - skipping')
|
||||
print(f"{unzip_file} already exist - skipping")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
376
egs/librispeech/ASR/local/prepare_lang.py
Executable file
376
egs/librispeech/ASR/local/prepare_lang.py
Executable file
@ -0,0 +1,376 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script takes as input a lexicon file "data/lang/lexicon.txt"
|
||||
consisting of words and phones and does the following:
|
||||
|
||||
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
|
||||
|
||||
2. Generate phones.txt, the phones table mapping a phone to a unique integer.
|
||||
|
||||
3. Generate words.txt, the words table mapping a word to a unique integer.
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
||||
6. Generate lexicon_disambig.txt
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def read_lexicon(filename: str) -> Lexicon:
|
||||
"""Read a lexicon.txt in `filename`.
|
||||
|
||||
Each line in the lexicon contains "word p1 p2 p3 ...".
|
||||
That is, the first field is a word and the remaining
|
||||
fields are phones. Fields are separated by space(s).
|
||||
|
||||
We assume that the input lexicon does not contain words:
|
||||
<eps>, <s>, </s>, !SIL, <SPOKEN_NOISE>, <UNK>.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Path to the lexicon.txt
|
||||
|
||||
Returns:
|
||||
A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])]
|
||||
"""
|
||||
# ans = ["!SIL", ["SIL"]]
|
||||
# ans.append(["<SPOKEN_NOISE>", ["SPN"]])
|
||||
# ans.append(["<UNK>", ["SPN"]])
|
||||
|
||||
ans = []
|
||||
|
||||
with open(filename, "r", encoding="latin-1") as f:
|
||||
whitespace = re.compile("[ \t]+")
|
||||
for line in f:
|
||||
a = whitespace.split(line.strip(" \t\r\n"))
|
||||
if len(a) == 0:
|
||||
continue
|
||||
|
||||
if len(a) < 2:
|
||||
print(f"Found bad line {line} in lexicon file {filename}")
|
||||
print("Every line is expected to contain at least 2 fields")
|
||||
sys.exit(1)
|
||||
word = a[0]
|
||||
if word == "<eps>":
|
||||
print(f"Found bad line {line} in lexicon file {filename}")
|
||||
print("<eps> should not be a valid word")
|
||||
sys.exit(1)
|
||||
|
||||
prons = a[1:]
|
||||
ans.append((word, prons))
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
def write_lexicon(filename: str, lexicon: Lexicon) -> None:
|
||||
"""Write a lexicon to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Path to the lexicon file to be generated.
|
||||
lexicon:
|
||||
It can be the return value of :func:`read_lexicon`.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
for word, prons in lexicon:
|
||||
f.write(f"{word} {' '.join(prons)}\n")
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
"""Write a symbol to ID mapping to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename to save the mapping.
|
||||
sym2id:
|
||||
A dict mapping symbols to IDs.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
for sym, i in sym2id.items():
|
||||
f.write(f"{sym} {i}\n")
|
||||
|
||||
|
||||
def get_phones(lexicon: Lexicon) -> List[str]:
|
||||
"""Get phones from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique phones.
|
||||
"""
|
||||
ans = set()
|
||||
for _, prons in lexicon:
|
||||
ans.update(prons)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def get_words(lexicon: List[Tuple[str, List[str]]]) -> List[str]:
|
||||
"""Get words from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique words.
|
||||
"""
|
||||
ans = set()
|
||||
for word, _ in lexicon:
|
||||
ans.add(word)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
|
||||
"""It adds pseudo-phone disambiguation symbols #1, #2 and so on
|
||||
at the ends of phones to ensure that all pronunciations are different,
|
||||
and that none is a prefix of another.
|
||||
|
||||
See also add_lex_disambig.pl from kaldi.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is returned by :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
|
||||
- The output lexicon with disambiguation symbols
|
||||
- The ID of the max disambiguation symbols
|
||||
"""
|
||||
|
||||
# (1) Work out the count of each phone-sequence in the
|
||||
# lexicon.
|
||||
count = defaultdict(int)
|
||||
for _, prons in lexicon:
|
||||
count[" ".join(prons)] += 1
|
||||
|
||||
# (2) For each left sub-sequence of each phone-sequence, note down
|
||||
# that it exists (for identifying prefixes of longer strings).
|
||||
issubseq = defaultdict(int)
|
||||
for _, prons in lexicon:
|
||||
prons = prons.copy()
|
||||
prons.pop()
|
||||
while prons:
|
||||
issubseq[" ".join(prons)] = 1
|
||||
prons.pop()
|
||||
|
||||
# (3) For each entry in the lexicon:
|
||||
# if the phone sequence is unique and is not a
|
||||
# prefix of another word, no disambig symbol.
|
||||
# Else output #1, or #2, #3, ... if the same phone-seq
|
||||
# has already been assigned a disambig symbol.
|
||||
ans = []
|
||||
|
||||
# We start with #1 since #0 has its own purpose
|
||||
first_allowed_disambig = 1
|
||||
max_disambig = first_allowed_disambig - 1
|
||||
last_used_disambig_symbol_of = defaultdict(int)
|
||||
|
||||
for word, prons in lexicon:
|
||||
phnseq = " ".join(prons)
|
||||
assert phnseq != ""
|
||||
if issubseq[phnseq] == 0 and count[phnseq] == 1:
|
||||
ans.append((word, prons))
|
||||
continue
|
||||
|
||||
cur_disambig = last_used_disambig_symbol_of[phnseq]
|
||||
if cur_disambig == 0:
|
||||
cur_disambig = first_allowed_disambig
|
||||
else:
|
||||
cur_disambig += 1
|
||||
|
||||
if cur_disambig > max_disambig:
|
||||
max_disambig = cur_disambig
|
||||
last_used_disambig_symbol_of[phnseq] = cur_disambig
|
||||
phnseq += f" #{cur_disambig}"
|
||||
ans.append((word, phnseq.split()))
|
||||
return ans, max_disambig
|
||||
|
||||
|
||||
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
|
||||
"""Generate ID maps, i.e., map a symbol to a unique ID.
|
||||
|
||||
Args:
|
||||
symbols:
|
||||
A list of unique symbols.
|
||||
Returns:
|
||||
A dict containing the mapping between symbols and IDs.
|
||||
"""
|
||||
return {sym: i for i, sym in enumerate(symbols)}
|
||||
|
||||
|
||||
def lexicon_to_fst(
|
||||
lexicon: Lexicon,
|
||||
phone2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
sil_phone: str = "SIL",
|
||||
sil_prob: float = 0.5,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format) with optional silence at
|
||||
the beginning and end of the word.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
phone2id:
|
||||
A dict mapping phones to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
sil_phone:
|
||||
The silence phone.
|
||||
sil_prob:
|
||||
The probability for adding a silence at the beginning and end
|
||||
of the word.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
assert sil_prob > 0.0 and sil_prob < 1.0
|
||||
# CAUTION: we use score, i.e, negative cost.
|
||||
sil_score = math.log(sil_prob)
|
||||
no_sil_score = math.log(1.0 - sil_prob)
|
||||
|
||||
start_state = 0
|
||||
loop_state = 1 # words enter and leave from here
|
||||
sil_state = 2 # words terminate here when followed by silence; this state
|
||||
# has a silence transition to loop_state.
|
||||
next_state = 3 # the next un-allocated state, will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
assert phone2id["<eps>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
sil_phone = phone2id[sil_phone]
|
||||
|
||||
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
|
||||
arcs.append([start_state, sil_state, eps, eps, sil_score])
|
||||
arcs.append([sil_state, loop_state, sil_phone, eps, 0])
|
||||
|
||||
for word, prons in lexicon:
|
||||
assert len(prons) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
prons = [phone2id[i] for i in prons]
|
||||
|
||||
for i in range(len(prons) - 1):
|
||||
if i == 0:
|
||||
arcs.append([cur_state, next_state, prons[i], word, 0])
|
||||
else:
|
||||
arcs.append([cur_state, next_state, prons[i], eps, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last phone of this word
|
||||
# It has two out-going arcs, one to the loop state,
|
||||
# the other one to the sil_state.
|
||||
i = len(prons) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, prons[i], w, no_sil_score])
|
||||
arcs.append([cur_state, sil_state, prons[i], w, sil_score])
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path("data/lang")
|
||||
lexicon_filename = out_dir / "lexicon.txt"
|
||||
sil_phone = "SIL"
|
||||
sil_prob = 0.5
|
||||
|
||||
lexicon = read_lexicon(lexicon_filename)
|
||||
phones = get_phones(lexicon)
|
||||
words = get_words(lexicon)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in phones
|
||||
phones.append(f"#{i}")
|
||||
|
||||
assert "<eps>" not in phones
|
||||
phones = ["<eps>"] + phones
|
||||
|
||||
assert "<eps>" not in words
|
||||
assert "#0" not in words
|
||||
assert "<s>" not in words
|
||||
assert "</s>" not in words
|
||||
|
||||
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
|
||||
|
||||
phone2id = generate_id_map(phones)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
write_mapping(out_dir / "phones.txt", phone2id)
|
||||
write_mapping(out_dir / "words.txt", word2id)
|
||||
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst(
|
||||
lexicon,
|
||||
phone2id=phone2id,
|
||||
word2id=word2id,
|
||||
sil_phone=sil_phone,
|
||||
sil_prob=sil_prob,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst(
|
||||
lexicon_disambig,
|
||||
phone2id=phone2id,
|
||||
word2id=word2id,
|
||||
sil_phone=sil_phone,
|
||||
sil_prob=sil_prob,
|
||||
)
|
||||
|
||||
# TODO(fangjun): add self-loops to L_disambig
|
||||
# whose ilabel is phone2id['#0'] and olable is word2id['#0']
|
||||
# Need to implement it in k2
|
||||
|
||||
if False:
|
||||
# Just for debugging, will remove it
|
||||
torch.save(L.as_dict(), out_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
|
||||
|
||||
L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
|
||||
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
|
||||
L_disambig.labels_sym = L.labels_sym
|
||||
L_disambig.aux_labels_sym = L.aux_labels_sym
|
||||
L.draw(out_dir / "L.svg", title="L")
|
||||
L_disambig.draw(out_dir / "L_disambig.svg", title="L_disambig")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
86
egs/librispeech/ASR/local/test_prepare_lang.py
Executable file
86
egs/librispeech/ASR/local/test_prepare_lang.py
Executable file
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
from prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
get_phones,
|
||||
get_words,
|
||||
lexicon_to_fst,
|
||||
read_lexicon,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
|
||||
def generate_lexicon_file() -> str:
|
||||
fd, filename = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
s = """
|
||||
!SIL SIL
|
||||
<SPOKEN_NOISE> SPN
|
||||
<UNK> SPN
|
||||
f f
|
||||
a a
|
||||
foo f o o
|
||||
bar b a r
|
||||
bark b a r k
|
||||
food f o o d
|
||||
food2 f o o d
|
||||
fo f o
|
||||
""".strip()
|
||||
with open(filename, "w") as f:
|
||||
f.write(s)
|
||||
return filename
|
||||
|
||||
|
||||
def test_read_lexicon(filename: str):
|
||||
lexicon = read_lexicon(filename)
|
||||
phones = get_phones(lexicon)
|
||||
words = get_words(lexicon)
|
||||
print(lexicon)
|
||||
print(phones)
|
||||
print(words)
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
print(lexicon_disambig)
|
||||
print("max disambig:", f"#{max_disambig}")
|
||||
|
||||
phones = ["<eps>", "SIL", "SPN"] + phones
|
||||
for i in range(max_disambig + 1):
|
||||
phones.append(f"#{i}")
|
||||
words = ["<eps>"] + words
|
||||
|
||||
phone2id = generate_id_map(phones)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
print(phone2id)
|
||||
print(word2id)
|
||||
|
||||
write_mapping("phones.txt", phone2id)
|
||||
write_mapping("words.txt", word2id)
|
||||
|
||||
write_lexicon("a.txt", lexicon)
|
||||
write_lexicon("a_disambig.txt", lexicon_disambig)
|
||||
|
||||
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
|
||||
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa.draw("L.pdf", title="L")
|
||||
|
||||
fsa_disambig = lexicon_to_fst(
|
||||
lexicon_disambig, phone2id=phone2id, word2id=word2id
|
||||
)
|
||||
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
|
||||
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
|
||||
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filename = generate_lexicon_file()
|
||||
test_read_lexicon(filename)
|
||||
os.remove(filename)
|
Loading…
x
Reference in New Issue
Block a user