WIP: Refactoring

Fangjun Kuang 2021-07-31 20:24:47 +08:00
parent c72a11ea1f
commit 1fa30998da
17 changed files with 195 additions and 122 deletions

.gitignore

@@ -4,3 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
+download/


@@ -62,7 +62,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
            "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,

@@ -367,15 +367,13 @@ def main():
     logging.info(f"device: {device}")
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()
-    # HLG = k2.ctc_topo(4999).to(device)
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",


@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 0.0,
             "subsampling_factor": 4,


@@ -188,7 +188,7 @@ class Transformer(nn.Module):
         encoder_mask: Tensor,
         supervision: Supervisions = None,
         graph_compiler: object = None,
-        token_ids: List[int] = None,
+        token_ids: List[List[int]] = None,
         sos_id: Optional[int] = None,
         eos_id: Optional[int] = None,
     ) -> Tensor:

@@ -199,6 +199,7 @@ class Transformer(nn.Module):
           supervision: Supervison in lhotse format, get from batch['supervisions']
           graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones)
             , graph_compiler.words and graph_compiler.oov
+          token_ids: A list of lists. Each list contains word piece IDs for an utterance.
           sos_id: sos token id
           eos_id: eos token id

@@ -210,7 +211,10 @@ class Transformer(nn.Module):
                 supervision, graph_compiler.lexicon.words, graph_compiler.oov
             )
             ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text, graph_compiler.L_inv, sos_id, eos_id,
+                batch_text,
+                graph_compiler.L_inv,
+                sos_id,
+                eos_id,
             )
         elif token_ids is not None:
             _sos = torch.tensor([sos_id])

@@ -225,7 +229,7 @@ class Transformer(nn.Module):
             ys_out_pad = pad_list(ys_out, -1)
         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")
         ys_in_pad = ys_in_pad.to(x.device)
         ys_out_pad = ys_out_pad.to(x.device)

@@ -284,7 +288,7 @@ class Transformer(nn.Module):
             ys_in_pad = pad_list(ys_in, eos_id)
             ys_out_pad = pad_list(ys_out, -1)
         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")
         ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64)
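The token_ids branch above builds decoder inputs and targets by prepending sos_id, appending eos_id, and padding (inputs are padded with eos_id, targets with -1 so the loss ignores them). Below is a minimal, self-contained sketch of that idea; it is not this repo's code, and the pad_list helper here is a stand-in built on pad_sequence.

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_list(ys: List[torch.Tensor], pad_value: int) -> torch.Tensor:
    # Stand-in for the pad_list used above: pad 1-D tensors to equal length.
    return pad_sequence(ys, batch_first=True, padding_value=pad_value)


def build_decoder_targets(token_ids: List[List[int]], sos_id: int, eos_id: int):
    # token_ids: one list of word piece IDs per utterance.
    _sos = torch.tensor([sos_id])
    _eos = torch.tensor([eos_id])
    ys = [torch.tensor(y, dtype=torch.int64) for y in token_ids]
    ys_in = [torch.cat([_sos, y]) for y in ys]   # decoder input: <sos> y1 ... yn
    ys_out = [torch.cat([y, _eos]) for y in ys]  # decoder target: y1 ... yn <eos>
    ys_in_pad = pad_list(ys_in, eos_id)   # inputs padded with eos_id
    ys_out_pad = pad_list(ys_out, -1)     # targets padded with -1 (ignored by the loss)
    return ys_in_pad, ys_out_pad


# Two utterances of different lengths:
ys_in_pad, ys_out_pad = build_decoder_targets([[5, 7, 9], [11]], sos_id=1, eos_id=2)
print(ys_in_pad)   # [[1, 5, 7, 9], [1, 11, 2, 2]]
print(ys_out_pad)  # [[5, 7, 9, 2], [11, 2, -1, -1]]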


@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang or data/lang/bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe.

     Return:
       An FSA representing HLG.

@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     return HLG

-def phone_based_HLG():
-    if Path("data/lm/HLG.pt").is_file():
-        return
-    logging.info("Compiling phone based HLG")
-    HLG = compile_HLG("data/lang")
-    logging.info("Saving HLG.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG.pt")
-
-def bpe_based_HLG():
-    if Path("data/lm/HLG_bpe.pt").is_file():
-        return
-    logging.info("Compiling BPE based HLG")
-    HLG = compile_HLG("data/lang/bpe")
-    logging.info("Saving HLG_bpe.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
-
 def main():
-    phone_based_HLG()
-    bpe_based_HLG()
+    for d in ["data/lang_phone", "data/lang_bpe"]:
+        d = Path(d)
+        logging.info(f"Processing {d}")
+        if (d / "HLG.pt").is_file():
+            logging.info(f"{d}/HLG.pt already exists - skipping")
+            continue
+        HLG = compile_HLG(d)
+        logging.info(f"Saving HLG.pt to {d}")
+        torch.save(HLG.as_dict(), f"{d}/HLG.pt")

 if __name__ == "__main__":


@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 """
-This file computes fbank features of the librispeech dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+This file computes fbank features of the LibriSpeech dataset.
+Its looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """
+import logging
 import os
 from pathlib import Path

@@ -40,9 +42,9 @@ def compute_fbank_librispeech():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
             if (output_dir / f"cuts_{partition}.json.gz").is_file():
-                print(f"{partition} already exists - skipping.")
+                logging.info(f"{partition} already exists - skipping.")
                 continue
-            print("Processing", partition)
+            logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
                 recordings=m["recordings"],
                 supervisions=m["supervisions"],

@@ -65,4 +67,10 @@ def compute_fbank_librispeech():
 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_librispeech()


@@ -2,10 +2,12 @@
 """
 This file computes fbank features of the musan dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+Its looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """
+import logging
 import os
 from pathlib import Path

@@ -34,10 +36,10 @@ def compute_fbank_musan():
     musan_cuts_path = output_dir / "cuts_musan.json.gz"
     if musan_cuts_path.is_file():
-        print(f"{musan_cuts_path} already exists - skipping")
+        logging.info(f"{musan_cuts_path} already exists - skipping")
         return
-    print("Extracting features for Musan")
+    logging.info("Extracting features for Musan")
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -63,4 +65,9 @@ def compute_fbank_musan():
 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()


@@ -2,10 +2,25 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This file downloads librispeech LM files to data/lm
+This file downloads the following LibriSpeech LM files:
+
+  - 3-gram.pruned.1e-7.arpa.gz
+  - 4-gram.arpa.gz
+  - librispeech-vocab.txt
+  - librispeech-lexicon.txt
+
+from http://www.openslr.org/resources/11
+and save them in the user provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+  ./local/download_lm.py --out-dir ./download/lm
 """

+import argparse
 import gzip
+import logging
 import os
 import shutil
 from pathlib import Path

@@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
 from tqdm.auto import tqdm

-def download_lm():
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--out-dir", type=str, help="Output directory.")
+    args = parser.parse_args()
+    return args
+
+
+def main(out_dir: str):
     url = "http://www.openslr.org/resources/11"
-    target_dir = Path("data/lm")
+    out_dir = Path(out_dir)

     files_to_download = (
         "3-gram.pruned.1e-7.arpa.gz",

@@ -26,7 +49,7 @@ def download_lm():
     )

     for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
-        filename = target_dir / f
+        filename = out_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
                 f"{url}/{f}",

@@ -34,17 +57,26 @@ def download_lm():
                 desc=f"Downloading {filename}",
             )
         else:
-            print(f"{filename} already exists - skipping")
+            logging.info(f"{filename} already exists - skipping")

         if ".gz" in str(filename):
-            unzip_file = Path(os.path.splitext(filename)[0])
-            if unzip_file.is_file() is False:
+            unzipped = Path(os.path.splitext(filename)[0])
+            if unzipped.is_file() is False:
                 with gzip.open(filename, "rb") as f_in:
-                    with open(unzip_file, "wb") as f_out:
+                    with open(unzipped, "wb") as f_out:
                         shutil.copyfileobj(f_in, f_out)
             else:
-                print(f"{unzip_file} already exist - skipping")
+                logging.info(f"{unzipped} already exist - skipping")


 if __name__ == "__main__":
-    download_lm()
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+    main(out_dir=args.out_dir)


@@ -3,7 +3,7 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as input a lexicon file "data/lang/lexicon.txt"
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
 consisting of words and tokens (i.e., phones) and does the following:

 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

@@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
 import math
-import re
-import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple

@@ -284,7 +282,9 @@ def lexicon_to_fst(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state

@@ -301,7 +301,7 @@ def lexicon_to_fst(
 def main():
-    out_dir = Path("data/lang")
+    out_dir = Path("data/lang_phone")
     lexicon_filename = out_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5


@@ -5,10 +5,10 @@
 """
 This script takes as inputs the following two files:

-    - data/lang/bpe/bpe.model,
-    - data/lang/bpe/words.txt
+    - data/lang_bpe/bpe.model,
+    - data/lang_bpe/words.txt

-and generates the following files in the directory data/lang/bpe:
+and generates the following files in the directory data/lang_bpe:

     - lexicon.txt
     - lexicon_disambig.txt

@@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state

@@ -140,7 +142,7 @@ def generate_lexicon(
 def main():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = Path("data/lang_bpe")
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

@@ -173,7 +175,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )

     L_disambig = lexicon_to_fst_no_sil(
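For context: per its docstring, this script turns data/lang_bpe/bpe.model plus data/lang_bpe/words.txt into a word piece lexicon. The sketch below shows only the core mapping with the sentencepiece API; the helper name words_to_piece_lexicon and the words.txt parsing are assumptions for illustration, not the script's actual implementation (which also handles disambiguation symbols and FST conversion, as the hunks above show).

from pathlib import Path
from typing import List, Tuple

import sentencepiece as spm


def words_to_piece_lexicon(
    model_file: str, words: List[str]
) -> List[Tuple[str, List[str]]]:
    # Map each word to its word pieces, e.g. ("HELLO", ["▁HE", "LLO"]).
    sp = spm.SentencePieceProcessor(model_file=model_file)
    return [(w, sp.encode(w, out_type=str)) for w in words]


if __name__ == "__main__":
    lang_dir = Path("data/lang_bpe")  # layout used in this commit
    # words.txt stores "word id" pairs; the first column is the word.
    words = [
        line.split()[0]
        for line in (lang_dir / "words.txt").read_text().splitlines()
        if line.strip()
    ]
    for word, pieces in words_to_piece_lexicon(str(lang_dir / "bpe.model"), words)[:5]:
        print(word, " ".join(pieces))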


@@ -14,18 +14,17 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96

+import shutil
 from pathlib import Path

 import sentencepiece as spm
-import shutil


 def main():
     model_type = "unigram"
     vocab_size = 5000
-    model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
-    train_text = "data/lang/bpe/train.txt"
+    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
+    train_text = "data/lang_bpe/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -53,7 +52,7 @@ def main():
     sp = spm.SentencePieceProcessor(model_file=str(model_file))
     vocab_size = sp.vocab_size()

-    shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
+    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")

 if __name__ == "__main__":
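The options visible in this hunk (model_type, vocab_size, model_prefix, train_text, character_coverage, input_sentence_size) are standard sentencepiece training arguments. A hedged sketch of how such a training call typically looks is below; any options the real script passes beyond these are not shown in the diff, so treat this as an illustration only.

import sentencepiece as spm

# Illustration only: train a unigram model from a plain-text file, using the
# parameter names visible in the hunk above.
model_type = "unigram"
vocab_size = 5000
model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
train_text = "data/lang_bpe/train.txt"

spm.SentencePieceTrainer.train(
    input=train_text,            # one sentence per line
    vocab_size=vocab_size,
    model_type=model_type,
    model_prefix=model_prefix,   # writes <prefix>.model and <prefix>.vocab
    character_coverage=1.0,
    input_sentence_size=100000000,
)

# The trained model can then be loaded and inspected:
sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
print(sp.vocab_size())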


@@ -6,8 +6,38 @@ nj=15
 stage=-1
 stop_stage=100

-. local/parse_options.sh || exit 1
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/LibriSpeech
+#      You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#      You can download them from https://www.openslr.org/12
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#      http://www.openslr.org/resources/11
+#
+#        - 3-gram.pruned.1e-7.arpa.gz
+#        - 3-gram.pruned.1e-7.arpa
+#        - 4-gram.arpa.gz
+#        - 4-gram.arpa
+#        - librispeech-vocab.txt
+#        - librispeech-lexicon.txt
+#
+#  - $do_dir/musan
+#      This directory contains the following directories downloaded from
+#      http://www.openslr.org/17/
+#
+#        - music
+#        - noise
+#        - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1

+# All generated files by this script are saved in "data"
 mkdir -p data

 log() {

@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
-  mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
-  #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi

   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare librispeech manifest"
-  # We assume that you have downloaded the librispeech corpus
-  # to data/LibriSpeech
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then

@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then

@@ -84,24 +105,25 @@ fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt

-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ ! -f data/lang_phone/L_disambig.pt ]; then
     ./local/prepare_lang.py
   fi
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "State 6: Prepare BPE based lang"
-  mkdir -p data/lang/bpe
-  cp data/lang/words.txt data/lang/bpe/
+  mkdir -p data/lang_bpe
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp data/lang_phone/words.txt data/lang_bpe/

-  if [ ! -f data/lang/bpe/train.txt ]; then
+  if [ ! -f data/lang_bpe/train.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"

@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > data/lang/bpe/train.txt
+    done > data/lang_bpe/train.txt
   fi

   python3 ./local/train_bpe_model.py

-  if [ ! -f data/lang/bpe/L_disambig.pt ]; then
+  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py
   fi
 fi

@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   # We assume you have install kaldilm, if not, please install
   # it using: pip install kaldilm

+  mkdir -p data/lm
   if [ ! -f data/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
   fi

   if [ ! -f data/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
   fi
 fi

egs/librispeech/ASR/shared (symbolic link)

@@ -0,0 +1 @@
+../../../icefall/shared/


@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp/"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "subsampling_factor": 3,

@@ -328,7 +328,7 @@ def main():
     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
+    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

@@ -340,7 +340,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
             # G.aux_labels is not needed in later computations, so


@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lr": 1e-3,
             "feature_dim": 80,
             "weight_decay": 5e-4,


@@ -1,7 +1,8 @@
 import logging
 import re
+import sys
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Tuple

 import k2
 import torch

@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue

             if len(a) < 2:
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)

             word = a[0]
             if word == "<eps>":
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("<eps> should not be a valid word")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info("<eps> should not be a valid word")
                 sys.exit(1)

             tokens = a[1:]

@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
 class Lexicon(object):
-    """Phone based lexicon.
-
-    TODO: Add BpeLexicon for BPE models.
-    """
+    """Phone based lexicon."""

     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Args:

@@ -121,7 +127,9 @@ class Lexicon(object):
 class BpeLexicon(Lexicon):
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.
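For readers unfamiliar with the lexicon format that read_lexicon handles: each non-empty line is a word followed by its tokens (phones for data/lang_phone, word pieces for data/lang_bpe). A tiny illustration with a made-up entry:

# Hypothetical lexicon entry, in the format read_lexicon parses.
line = "HELLO HH AH0 L OW1"  # a word followed by its tokens
a = line.split()
word, tokens = a[0], a[1:]
assert len(a) >= 2 and word != "<eps>"  # the checks read_lexicon enforces
print(word, tokens)  # HELLO ['HH', 'AH0', 'L', 'OW1']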