From 1fa30998da5ac06d4c742227cc949ed68c256df7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 31 Jul 2021 20:24:47 +0800
Subject: [PATCH] WIP: Refactoring

---
 .gitignore                                    |   1 +
 egs/librispeech/ASR/conformer_ctc/decode.py   |   6 +-
 egs/librispeech/ASR/conformer_ctc/train.py    |   2 +-
 .../ASR/conformer_ctc/transformer.py          |  12 ++-
 egs/librispeech/ASR/local/compile_hlg.py      |  36 +++----
 .../ASR/local/compute_fbank_librispeech.py    |  18 +++-
 .../ASR/local/compute_fbank_musan.py          |  15 ++-
 egs/librispeech/ASR/local/download_lm.py      |  52 +++++++--
 egs/librispeech/ASR/local/prepare_lang.py     |  10 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |  16 +--
 egs/librispeech/ASR/local/train_bpe_model.py  |   9 +-
 egs/librispeech/ASR/prepare.sh                | 101 +++++++++++-------
 egs/librispeech/ASR/shared                    |   1 +
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |   6 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py    |   2 +-
 icefall/lexicon.py                            |  30 ++++--
 .../local => icefall/shared}/parse_options.sh |   0
 17 files changed, 195 insertions(+), 122 deletions(-)
 create mode 120000 egs/librispeech/ASR/shared
 rename {egs/librispeech/ASR/local => icefall/shared}/parse_options.sh (100%)

diff --git a/.gitignore b/.gitignore
index 6cb9f2299..839a1c34a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
+download/
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index d1cbc14de..3a8db1b81 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -62,7 +62,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,
@@ -367,15 +367,13 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()
 
-    # HLG = k2.ctc_topo(4999).to(device)
-
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 40d3cf7fb..d411a3783 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 0.0,
             "subsampling_factor": 4,
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index 1df16e346..06027cf64 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -188,7 +188,7 @@ class Transformer(nn.Module):
         encoder_mask: Tensor,
         supervision: Supervisions = None,
         graph_compiler: object = None,
-        token_ids: List[int] = None,
+        token_ids: List[List[int]] = None,
         sos_id: Optional[int] = None,
         eos_id: Optional[int] = None,
     ) -> Tensor:
@@ -199,6 +199,7 @@ class Transformer(nn.Module):
             supervision: Supervison in lhotse format, get from batch['supervisions']
             graph_compiler: use graph_compiler.L_inv (Its labels are words, while its
                 aux_labels are phones) , graph_compiler.words and graph_compiler.oov
+            token_ids: A list of lists. Each list contains word piece IDs for an utterance.
             sos_id: sos token id
             eos_id: eos token id
@@ -210,7 +211,10 @@ class Transformer(nn.Module):
                 supervision, graph_compiler.lexicon.words, graph_compiler.oov
             )
             ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text, graph_compiler.L_inv, sos_id, eos_id,
+                batch_text,
+                graph_compiler.L_inv,
+                sos_id,
+                eos_id,
             )
         elif token_ids is not None:
             _sos = torch.tensor([sos_id])
@@ -225,7 +229,7 @@ class Transformer(nn.Module):
             ys_out_pad = pad_list(ys_out, -1)
 
         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")
 
         ys_in_pad = ys_in_pad.to(x.device)
         ys_out_pad = ys_out_pad.to(x.device)
@@ -284,7 +288,7 @@ class Transformer(nn.Module):
             ys_in_pad = pad_list(ys_in, eos_id)
             ys_out_pad = pad_list(ys_out, -1)
         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")
 
         ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64)
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 605d72dae..c02fb7c0d 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang or data/lang/bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe.
 
     Return:
       An FSA representing HLG.
@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     return HLG
 
 
-def phone_based_HLG():
-    if Path("data/lm/HLG.pt").is_file():
-        return
-
-    logging.info("Compiling phone based HLG")
-    HLG = compile_HLG("data/lang")
-
-    logging.info("Saving HLG.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG.pt")
-
-
-def bpe_based_HLG():
-    if Path("data/lm/HLG_bpe.pt").is_file():
-        return
-
-    logging.info("Compiling BPE based HLG")
-    HLG = compile_HLG("data/lang/bpe")
-    logging.info("Saving HLG_bpe.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
-
-
 def main():
-    phone_based_HLG()
-    bpe_based_HLG()
+    for d in ["data/lang_phone", "data/lang_bpe"]:
+        d = Path(d)
+        logging.info(f"Processing {d}")
+
+        if (d / "HLG.pt").is_file():
+            logging.info(f"{d}/HLG.pt already exists - skipping")
+            continue
+
+        HLG = compile_HLG(d)
+        logging.info(f"Saving HLG.pt to {d}")
+        torch.save(HLG.as_dict(), f"{d}/HLG.pt")
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
index 947d9f8d9..0c07aaa1a 100755
--- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py
+++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 
 """
-This file computes fbank features of the librispeech dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+This file computes fbank features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """
 
+import logging
 import os
 from pathlib import Path
 
@@ -40,9 +42,9 @@ def compute_fbank_librispeech():
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
             if (output_dir / f"cuts_{partition}.json.gz").is_file():
-                print(f"{partition} already exists - skipping.")
+                logging.info(f"{partition} already exists - skipping.")
                 continue
-            print("Processing", partition)
+            logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
                 recordings=m["recordings"],
                 supervisions=m["supervisions"],
@@ -65,4 +67,10 @@ def compute_fbank_librispeech():
 
 
 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     compute_fbank_librispeech()
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index d63131da8..6a46e6978 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -2,10 +2,12 @@
 
 """
 This file computes fbank features of the musan dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """
 
+import logging
 import os
 from pathlib import Path
@@ -34,10 +36,10 @@ def compute_fbank_musan():
 
     musan_cuts_path = output_dir / "cuts_musan.json.gz"
 
     if musan_cuts_path.is_file():
-        print(f"{musan_cuts_path} already exists - skipping")
+        logging.info(f"{musan_cuts_path} already exists - skipping")
         return
 
-    print("Extracting features for Musan")
+    logging.info("Extracting features for Musan")
 
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
@@ -63,4 +65,9 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py
index 0bdc2935b..5c9e2a675 100755
--- a/egs/librispeech/ASR/local/download_lm.py
+++ b/egs/librispeech/ASR/local/download_lm.py
@@ -2,10 +2,25 @@
 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
 
 """
-This file downloads librispeech LM files to data/lm
+This file downloads the following LibriSpeech LM files:
+
+  - 3-gram.pruned.1e-7.arpa.gz
+  - 4-gram.arpa.gz
+  - librispeech-vocab.txt
+  - librispeech-lexicon.txt
+
+from http://www.openslr.org/resources/11
+and saves them in the user-provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+    ./local/download_lm.py --out-dir ./download/lm
 """
 
+import argparse
 import gzip
+import logging
 import os
 import shutil
 from pathlib import Path
 
 from lhotse.utils import urlretrieve_progress
 from tqdm.auto import tqdm
 
 
-def download_lm():
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--out-dir", type=str, help="Output directory.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main(out_dir: str):
     url = "http://www.openslr.org/resources/11"
-    target_dir = Path("data/lm")
+    out_dir = Path(out_dir)
 
     files_to_download = (
         "3-gram.pruned.1e-7.arpa.gz",
@@ -26,7 +49,7 @@ def download_lm():
     )
 
     for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
-        filename = target_dir / f
+        filename = out_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
                 f"{url}/{f}",
@@ -34,17 +57,26 @@ def download_lm():
                 desc=f"Downloading {filename}",
             )
         else:
-            print(f"{filename} already exists - skipping")
+            logging.info(f"{filename} already exists - skipping")
 
         if ".gz" in str(filename):
-            unzip_file = Path(os.path.splitext(filename)[0])
-            if unzip_file.is_file() is False:
+            unzipped = Path(os.path.splitext(filename)[0])
+            if unzipped.is_file() is False:
                 with gzip.open(filename, "rb") as f_in:
-                    with open(unzip_file, "wb") as f_out:
+                    with open(unzipped, "wb") as f_out:
                         shutil.copyfileobj(f_in, f_out)
             else:
-                print(f"{unzip_file} already exist - skipping")
+                logging.info(f"{unzipped} already exists - skipping")
 
 
 if __name__ == "__main__":
-    download_lm()
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+
+    main(out_dir=args.out_dir)
diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py
index b9d13f5bb..f7fde7796 100755
--- a/egs/librispeech/ASR/local/prepare_lang.py
+++ b/egs/librispeech/ASR/local/prepare_lang.py
@@ -3,7 +3,7 @@
 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
 
 """
-This script takes as input a lexicon file "data/lang/lexicon.txt"
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
 consisting of words and tokens (i.e., phones) and does the following:
 
 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@@ -20,8 +20,6 @@
 
 5. Generate L_disambig.pt, in k2 format.
""" import math -import re -import sys from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Tuple @@ -284,7 +282,9 @@ def lexicon_to_fst( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -301,7 +301,7 @@ def lexicon_to_fst( def main(): - out_dir = Path("data/lang") + out_dir = Path("data/lang_phone") lexicon_filename = out_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index 0c3e9ede5..e31220d9b 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -5,10 +5,10 @@ """ This script takes as inputs the following two files: - - data/lang/bpe/bpe.model, - - data/lang/bpe/words.txt + - data/lang_bpe/bpe.model, + - data/lang_bpe/words.txt -and generates the following files in the directory data/lang/bpe: +and generates the following files in the directory data/lang_bpe: - lexicon.txt - lexicon_disambig.txt @@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -140,7 +142,7 @@ def generate_lexicon( def main(): - lang_dir = Path("data/lang/bpe") + lang_dir = Path("data/lang_bpe") model_file = lang_dir / "bpe.model" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") @@ -173,7 +175,9 @@ def main(): write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( - lexicon, token2id=token_sym_table, word2id=word_sym_table, + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, ) L_disambig = lexicon_to_fst_no_sil( diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index b5c6c7541..59746ad9a 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -14,18 +14,17 @@ and generates "data/lang/bpe/bep.model". # # Please install a version >=0.1.96 +import shutil from pathlib import Path import sentencepiece as spm -import shutil - def main(): model_type = "unigram" vocab_size = 5000 - model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}" - train_text = "data/lang/bpe/train.txt" + model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}" + train_text = "data/lang_bpe/train.txt" character_coverage = 1.0 input_sentence_size = 100000000 @@ -53,7 +52,7 @@ def main(): sp = spm.SentencePieceProcessor(model_file=str(model_file)) vocab_size = sp.vocab_size() - shutil.copyfile(model_file, "data/lang/bpe/bpe.model") + shutil.copyfile(model_file, "data/lang_bpe/bpe.model") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 406527b71..ae676b199 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -6,8 +6,38 @@ nj=15 stage=-1 stop_stage=100 -. local/parse_options.sh || exit 1 +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. 
+#
+# - $dl_dir/LibriSpeech
+#   You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#   You can download them from https://www.openslr.org/12
+#
+# - $dl_dir/lm
+#   This directory contains the following files downloaded from
+#   http://www.openslr.org/resources/11
+#
+#   - 3-gram.pruned.1e-7.arpa.gz
+#   - 3-gram.pruned.1e-7.arpa
+#   - 4-gram.arpa.gz
+#   - 4-gram.arpa
+#   - librispeech-vocab.txt
+#   - librispeech-lexicon.txt
+#
+# - $dl_dir/musan
+#   This directory contains the following directories downloaded from
+#   http://www.openslr.org/17/
+#
+#   - music
+#   - noise
+#   - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+
+# All files generated by this script are saved in "data".
 mkdir -p data
 
@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
-  mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi
 
@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
-  #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi
 
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare LibriSpeech manifest"
   # We assume that you have downloaded the LibriSpeech corpus
   # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi
 
@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Prepare musan manifest"
   # We assume that you have downloaded the musan corpus
-  # to data/musan
+  # to $dl_dir/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -84,24 +105,25 @@ fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone
 
   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt
 
-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ ! -f data/lang_phone/L_disambig.pt ]; then
     ./local/prepare_lang.py
   fi
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare BPE based lang"
-  mkdir -p data/lang/bpe
-  cp data/lang/words.txt data/lang/bpe/
+  mkdir -p data/lang_bpe
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp data/lang_phone/words.txt data/lang_bpe/
 
-  if [ ! -f data/lang/bpe/train.txt ]; then
+  if [ ! -f data/lang_bpe/train.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > data/lang/bpe/train.txt
+    done > data/lang_bpe/train.txt
   fi
 
   python3 ./local/train_bpe_model.py
 
-  if [ ! -f data/lang/bpe/L_disambig.pt ]; then
+  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py
   fi
 fi
 
@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   log "Stage 7: Prepare G"
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm
 
+  mkdir -p data/lm
   if [ ! -f data/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
   fi
 
   if [ ! -f data/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
   fi
 fi
diff --git a/egs/librispeech/ASR/shared b/egs/librispeech/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/librispeech/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 2c45b4e31..137fa795c 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp/"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "subsampling_factor": 3,
@@ -328,7 +328,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
+    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -340,7 +340,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]
 
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
 
             # G.aux_labels is not needed in later computations, so
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 3330b07a5..dbb9f64ec 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
"lang_dir": Path("data/lang_phone"), "lr": 1e-3, "feature_dim": 80, "weight_decay": 5e-4, diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 3b52c70c9..89747b11b 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -1,7 +1,8 @@ import logging import re +import sys from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Tuple import k2 import torch @@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]: continue if len(a) < 2: - print(f"Found bad line {line} in lexicon file {filename}") - print("Every line is expected to contain at least 2 fields") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) + logging.info( + "Every line is expected to contain at least 2 fields" + ) sys.exit(1) word = a[0] if word == "": - print(f"Found bad line {line} in lexicon file {filename}") - print(" should not be a valid word") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) + logging.info(" should not be a valid word") sys.exit(1) tokens = a[1:] @@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None: class Lexicon(object): - """Phone based lexicon. - - TODO: Add BpeLexicon for BPE models. - """ + """Phone based lexicon.""" def __init__( - self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), + self, + lang_dir: Path, + disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Args: @@ -121,7 +127,9 @@ class Lexicon(object): class BpeLexicon(Lexicon): def __init__( - self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), + self, + lang_dir: Path, + disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Refer to the help information in Lexicon.__init__. diff --git a/egs/librispeech/ASR/local/parse_options.sh b/icefall/shared/parse_options.sh similarity index 100% rename from egs/librispeech/ASR/local/parse_options.sh rename to icefall/shared/parse_options.sh