diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py
index 47251a5a0..0bdc2935b 100755
--- a/egs/librispeech/ASR/local/download_lm.py
+++ b/egs/librispeech/ASR/local/download_lm.py
@@ -29,10 +29,12 @@ def download_lm():
         filename = target_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
-                f"{url}/{f}", filename=filename, desc=f"Downloading {target_dir}/{f}",
+                f"{url}/{f}",
+                filename=filename,
+                desc=f"Downloading {target_dir}/{f}",
             )
         else:
-            print(f'{filename} already exists - skipping')
+            print(f"{filename} already exists - skipping")
 
         if ".gz" in str(filename):
             unzip_file = Path(os.path.splitext(filename)[0])
@@ -41,7 +43,7 @@ def download_lm():
                 with open(unzip_file, "wb") as f_out:
                     shutil.copyfileobj(f_in, f_out)
             else:
-                print(f'{unzip_file} already exist - skipping')
+                print(f"{unzip_file} already exist - skipping")
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py
new file mode 100755
index 000000000..a568cc8a9
--- /dev/null
+++ b/egs/librispeech/ASR/local/prepare_lang.py
@@ -0,0 +1,376 @@
+#!/usr/bin/env python3
+
+# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
+
+"""
+This script takes as input a lexicon file "data/lang/lexicon.txt"
+consisting of words and phones and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate phones.txt, the phones table mapping a phone to a unique integer.
+
+3. Generate words.txt, the words table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+     d = torch.load("L.pt")
+     lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+
+6. Generate lexicon_disambig.txt
+"""
+import math
+import re
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import k2
+import torch
+
+# A lexicon is a list of (word, pronunciation) pairs,
+# where the pronunciation is a list of phones.
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def read_lexicon(filename: str) -> Lexicon:
+    """Read a lexicon.txt in `filename`.
+
+    Each line in the lexicon contains "word p1 p2 p3 ...".
+    That is, the first field is a word and the remaining
+    fields are phones. Fields are separated by space(s).
+
+    We assume that the input lexicon does not contain words:
+    <eps>, <s>, </s>, !SIL, <SPN>, <UNK>.
+
+    Args:
+      filename:
+        Path to the lexicon.txt
+
+    Returns:
+      A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3', 'p4'])]
+    """
+    # ans = ["!SIL", ["SIL"]]
+    # ans.append(["<UNK>", ["SPN"]])
+    # ans.append(["<SPN>", ["SPN"]])
+
+    ans = []
+
+    with open(filename, "r", encoding="latin-1") as f:
+        whitespace = re.compile("[ \t]+")
+        for line in f:
+            a = whitespace.split(line.strip(" \t\r\n"))
+            if len(a) == 0:
+                continue
+
+            if len(a) < 2:
+                print(f"Found bad line {line} in lexicon file {filename}")
+                print("Every line is expected to contain at least 2 fields")
+                sys.exit(1)
+            word = a[0]
+            if word == "<eps>":
+                print(f"Found bad line {line} in lexicon file {filename}")
+                print("<eps> should not be a valid word")
+                sys.exit(1)
+
+            prons = a[1:]
+            ans.append((word, prons))
+
+    return ans
+
+
+def write_lexicon(filename: str, lexicon: Lexicon) -> None:
+    """Write a lexicon to a file.
+
+    Args:
+      filename:
+        Path to the lexicon file to be generated.
+      lexicon:
+        It can be the return value of :func:`read_lexicon`.
+    """
+    with open(filename, "w") as f:
+        for word, prons in lexicon:
+            f.write(f"{word} {' '.join(prons)}\n")
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+    """Write a symbol to ID mapping to a file.
+
+    Args:
+      filename:
+        Filename to save the mapping.
+      sym2id:
+        A dict mapping symbols to IDs.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w") as f:
+        for sym, i in sym2id.items():
+            f.write(f"{sym} {i}\n")
+
+
+def get_phones(lexicon: Lexicon) -> List[str]:
+    """Get phones from a lexicon.
+
+    Args:
+      lexicon:
+        It is the return value of :func:`read_lexicon`.
+    Returns:
+      Return a list of unique phones.
+    """
+    ans = set()
+    for _, prons in lexicon:
+        ans.update(prons)
+    sorted_ans = sorted(list(ans))
+    return sorted_ans
+
+
+def get_words(lexicon: List[Tuple[str, List[str]]]) -> List[str]:
+    """Get words from a lexicon.
+
+    Args:
+      lexicon:
+        It is the return value of :func:`read_lexicon`.
+    Returns:
+      Return a list of unique words.
+    """
+    ans = set()
+    for word, _ in lexicon:
+        ans.add(word)
+    sorted_ans = sorted(list(ans))
+    return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+    """It adds pseudo-phone disambiguation symbols #1, #2 and so on
+    at the ends of phones to ensure that all pronunciations are different,
+    and that none is a prefix of another.
+
+    See also add_lex_disambig.pl from kaldi.
+
+    Args:
+      lexicon:
+        It is returned by :func:`read_lexicon`.
+    Returns:
+      Return a tuple with two elements:
+
+        - The output lexicon with disambiguation symbols
+        - The ID of the max disambiguation symbols
+    """
+
+    # (1) Work out the count of each phone-sequence in the
+    # lexicon.
+    count = defaultdict(int)
+    for _, prons in lexicon:
+        count[" ".join(prons)] += 1
+
+    # (2) For each left sub-sequence of each phone-sequence, note down
+    # that it exists (for identifying prefixes of longer strings).
+    issubseq = defaultdict(int)
+    for _, prons in lexicon:
+        prons = prons.copy()
+        prons.pop()
+        while prons:
+            issubseq[" ".join(prons)] = 1
+            prons.pop()
+
+    # (3) For each entry in the lexicon:
+    # if the phone sequence is unique and is not a
+    # prefix of another word, no disambig symbol.
+    # Else output #1, or #2, #3, ... if the same phone-seq
+    # has already been assigned a disambig symbol.
+    ans = []
+
+    # We start with #1 since #0 has its own purpose
+    first_allowed_disambig = 1
+    max_disambig = first_allowed_disambig - 1
+    last_used_disambig_symbol_of = defaultdict(int)
+
+    for word, prons in lexicon:
+        phnseq = " ".join(prons)
+        assert phnseq != ""
+        if issubseq[phnseq] == 0 and count[phnseq] == 1:
+            ans.append((word, prons))
+            continue
+
+        cur_disambig = last_used_disambig_symbol_of[phnseq]
+        if cur_disambig == 0:
+            cur_disambig = first_allowed_disambig
+        else:
+            cur_disambig += 1
+
+        if cur_disambig > max_disambig:
+            max_disambig = cur_disambig
+        last_used_disambig_symbol_of[phnseq] = cur_disambig
+        phnseq += f" #{cur_disambig}"
+        ans.append((word, phnseq.split()))
+    return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+    """Generate ID maps, i.e., map a symbol to a unique ID.
+
+    Args:
+      symbols:
+        A list of unique symbols.
+    Returns:
+      A dict containing the mapping between symbols and IDs.
+    """
+    return {sym: i for i, sym in enumerate(symbols)}
+
+
+def lexicon_to_fst(
+    lexicon: Lexicon,
+    phone2id: Dict[str, int],
+    word2id: Dict[str, int],
+    sil_phone: str = "SIL",
+    sil_prob: float = 0.5,
+) -> k2.Fsa:
+    """Convert a lexicon to an FST (in k2 format) with optional silence at
+    the beginning and end of the word.
+
+    Args:
+      lexicon:
+        The input lexicon. See also :func:`read_lexicon`
+      phone2id:
+        A dict mapping phones to IDs.
+      word2id:
+        A dict mapping words to IDs.
+      sil_phone:
+        The silence phone.
+      sil_prob:
+        The probability for adding a silence at the beginning and end
+        of the word.
+    Returns:
+      Return an instance of `k2.Fsa` representing the given lexicon.
+    """
+    assert sil_prob > 0.0 and sil_prob < 1.0
+    # CAUTION: we use score, i.e., negative cost.
+    sil_score = math.log(sil_prob)
+    no_sil_score = math.log(1.0 - sil_prob)
+
+    start_state = 0
+    loop_state = 1  # words enter and leave from here
+    sil_state = 2  # words terminate here when followed by silence; this state
+    # has a silence transition to loop_state.
+    next_state = 3  # the next un-allocated state, will be incremented as we go.
+    arcs = []
+
+    assert phone2id["<eps>"] == 0
+    assert word2id["<eps>"] == 0
+
+    eps = 0
+
+    sil_phone = phone2id[sil_phone]
+
+    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+    arcs.append([start_state, sil_state, eps, eps, sil_score])
+    arcs.append([sil_state, loop_state, sil_phone, eps, 0])
+
+    for word, prons in lexicon:
+        assert len(prons) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        prons = [phone2id[i] for i in prons]
+
+        for i in range(len(prons) - 1):
+            if i == 0:
+                arcs.append([cur_state, next_state, prons[i], word, 0])
+            else:
+                arcs.append([cur_state, next_state, prons[i], eps, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last phone of this word
+        # It has two out-going arcs, one to the loop state,
+        # the other one to the sil_state.
+        i = len(prons) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, prons[i], w, no_sil_score])
+        arcs.append([cur_state, sil_state, prons[i], w, sil_score])
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def main():
+    out_dir = Path("data/lang")
+    lexicon_filename = out_dir / "lexicon.txt"
+    sil_phone = "SIL"
+    sil_prob = 0.5
+
+    lexicon = read_lexicon(lexicon_filename)
+    phones = get_phones(lexicon)
+    words = get_words(lexicon)
+
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+    for i in range(max_disambig + 1):
+        disambig = f"#{i}"
+        assert disambig not in phones
+        phones.append(f"#{i}")
+
+    assert "<eps>" not in phones
+    phones = ["<eps>"] + phones
+
+    assert "<eps>" not in words
+    assert "#0" not in words
+    assert "<s>" not in words
+    assert "</s>" not in words
+
+    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
+
+    phone2id = generate_id_map(phones)
+    word2id = generate_id_map(words)
+
+    write_mapping(out_dir / "phones.txt", phone2id)
+    write_mapping(out_dir / "words.txt", word2id)
+    write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+    L = lexicon_to_fst(
+        lexicon,
+        phone2id=phone2id,
+        word2id=word2id,
+        sil_phone=sil_phone,
+        sil_prob=sil_prob,
+    )
+
+    L_disambig = lexicon_to_fst(
+        lexicon_disambig,
+        phone2id=phone2id,
+        word2id=word2id,
+        sil_phone=sil_phone,
+        sil_prob=sil_prob,
+    )
+
+    # TODO(fangjun): add self-loops to L_disambig
+    # whose ilabel is phone2id['#0'] and olabel is word2id['#0']
+    # Need to implement it in k2
+
+    if False:
+        # Just for debugging, will remove it
+        torch.save(L.as_dict(), out_dir / "L.pt")
+        torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+        L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
+        L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+        L_disambig.labels_sym = L.labels_sym
+        L_disambig.aux_labels_sym = L.aux_labels_sym
+        L.draw(out_dir / "L.svg", title="L")
+        L_disambig.draw(out_dir / "L_disambig.svg", title="L_disambig")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py
new file mode 100755
index 000000000..f36ef55c6
--- /dev/null
+++ b/egs/librispeech/ASR/local/test_prepare_lang.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
+
+import os
+import tempfile
+
+import k2
+from prepare_lang import (
+    add_disambig_symbols,
+    generate_id_map,
+    get_phones,
+    get_words,
+    lexicon_to_fst,
+    read_lexicon,
+    write_lexicon,
+    write_mapping,
+)
+
+
+def generate_lexicon_file() -> str:
+    fd, filename = tempfile.mkstemp()
+    os.close(fd)
+    s = """
+    !SIL SIL
+    <SPN> SPN
+    <UNK> SPN
+    f f
+    a a
+    foo f o o
+    bar b a r
+    bark b a r k
+    food f o o d
+    food2 f o o d
+    fo f o
+    """.strip()
+    with open(filename, "w") as f:
+        f.write(s)
+    return filename
+
+
+def test_read_lexicon(filename: str):
+    lexicon = read_lexicon(filename)
+    phones = get_phones(lexicon)
+    words = get_words(lexicon)
+    print(lexicon)
+    print(phones)
+    print(words)
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+    print(lexicon_disambig)
+    print("max disambig:", f"#{max_disambig}")
+
+    phones = ["<eps>", "SIL", "SPN"] + phones
+    for i in range(max_disambig + 1):
+        phones.append(f"#{i}")
+    words = ["<eps>"] + words
+
+    phone2id = generate_id_map(phones)
+    word2id = generate_id_map(words)
+
+    print(phone2id)
+    print(word2id)
+
+    write_mapping("phones.txt", phone2id)
+    write_mapping("words.txt", word2id)
+
+    write_lexicon("a.txt", lexicon)
+    write_lexicon("a_disambig.txt", lexicon_disambig)
+
+    fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
+    fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
+    fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+    fsa.draw("L.pdf", title="L")
+
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
+    fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
+    fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+    fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
+
+
+if __name__ == "__main__":
+    filename = generate_lexicon_file()
+    test_read_lexicon(filename)
+    os.remove(filename)