WIP: Refactoring

Fangjun Kuang 2021-07-31 20:24:47 +08:00
parent c72a11ea1f
commit 1fa30998da
17 changed files with 195 additions and 122 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ path.sh
exp
exp*/
*.pt
download/

View File

@ -62,7 +62,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
@ -367,15 +367,13 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
# HLG = k2.ctc_topo(4999).to(device)
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",

View File

@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 0.0,
"subsampling_factor": 4,

View File

@ -188,7 +188,7 @@ class Transformer(nn.Module):
encoder_mask: Tensor,
supervision: Supervisions = None,
graph_compiler: object = None,
token_ids: List[int] = None,
token_ids: List[List[int]] = None,
sos_id: Optional[int] = None,
eos_id: Optional[int] = None,
) -> Tensor:
@ -199,6 +199,7 @@ class Transformer(nn.Module):
supervision: Supervision in lhotse format, obtained from batch['supervisions']
graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones)
, graph_compiler.words and graph_compiler.oov
token_ids: A list of lists. Each list contains word piece IDs for an utterance.
sos_id: sos token id
eos_id: eos token id
@ -210,7 +211,10 @@ class Transformer(nn.Module):
supervision, graph_compiler.lexicon.words, graph_compiler.oov
)
ys_in_pad, ys_out_pad = add_sos_eos(
batch_text, graph_compiler.L_inv, sos_id, eos_id,
batch_text,
graph_compiler.L_inv,
sos_id,
eos_id,
)
elif token_ids is not None:
_sos = torch.tensor([sos_id])
@ -225,7 +229,7 @@ class Transformer(nn.Module):
ys_out_pad = pad_list(ys_out, -1)
else:
raise ValueError("Invalid input for decoder self attetion")
raise ValueError("Invalid input for decoder self attention")
ys_in_pad = ys_in_pad.to(x.device)
ys_out_pad = ys_out_pad.to(x.device)
@ -284,7 +288,7 @@ class Transformer(nn.Module):
ys_in_pad = pad_list(ys_in, eos_id)
ys_out_pad = pad_list(ys_out, -1)
else:
raise ValueError("Invalid input for decoder self attetion")
raise ValueError("Invalid input for decoder self attention")
ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64)
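For context, a minimal sketch of the token_ids branch above: each utterance's word-piece IDs get sos_id prepended for the decoder input and eos_id appended for the target, then both are padded (the recipe uses its pad_list helper; torch's pad_sequence stands in here, and all ID values are hypothetical):

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical word-piece IDs for two utterances of different lengths.
token_ids = [[5, 17, 42], [8, 3]]
sos_id, eos_id = 1, 2

ys_in = [torch.tensor([sos_id] + utt) for utt in token_ids]   # decoder inputs start with sos
ys_out = [torch.tensor(utt + [eos_id]) for utt in token_ids]  # decoder targets end with eos

ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)  # -1 is the ignore index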

View File

@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang or data/lang/bpe.
The language directory, e.g., data/lang_phone or data/lang_bpe.
Return:
An FSA representing HLG.
@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
return HLG
def phone_based_HLG():
if Path("data/lm/HLG.pt").is_file():
return
logging.info("Compiling phone based HLG")
HLG = compile_HLG("data/lang")
logging.info("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
def bpe_based_HLG():
if Path("data/lm/HLG_bpe.pt").is_file():
return
logging.info("Compiling BPE based HLG")
HLG = compile_HLG("data/lang/bpe")
logging.info("Saving HLG_bpe.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
def main():
phone_based_HLG()
bpe_based_HLG()
for d in ["data/lang_phone", "data/lang_bpe"]:
d = Path(d)
logging.info(f"Processing {d}")
if (d / "HLG.pt").is_file():
logging.info(f"{d}/HLG.pt already exists - skipping")
continue
HLG = compile_HLG(d)
logging.info(f"Saving HLG.pt to {d}")
torch.save(HLG.as_dict(), f"{d}/HLG.pt")
if __name__ == "__main__":

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the librispeech dataset.
Its looks for manifests in the directory data/manifests
and generated fbank features are saved in data/fbank.
This file computes fbank features of the LibriSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
@ -40,9 +42,9 @@ def compute_fbank_librispeech():
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.json.gz").is_file():
print(f"{partition} already exists - skipping.")
logging.info(f"{partition} already exists - skipping.")
continue
print("Processing", partition)
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
@ -65,4 +67,10 @@ def compute_fbank_librispeech():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_librispeech()

View File

@ -2,10 +2,12 @@
"""
This file computes fbank features of the musan dataset.
Its looks for manifests in the directory data/manifests
and generated fbank features are saved in data/fbank.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
@ -34,10 +36,10 @@ def compute_fbank_musan():
musan_cuts_path = output_dir / "cuts_musan.json.gz"
if musan_cuts_path.is_file():
print(f"{musan_cuts_path} already exists - skipping")
logging.info(f"{musan_cuts_path} already exists - skipping")
return
print("Extracting features for Musan")
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -63,4 +65,9 @@ def compute_fbank_musan():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -2,10 +2,25 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This file downloads librispeech LM files to data/lm
This file downloads the following LibriSpeech LM files:
- 3-gram.pruned.1e-7.arpa.gz
- 4-gram.arpa.gz
- librispeech-vocab.txt
- librispeech-lexicon.txt
from http://www.openslr.org/resources/11
and saves them in the user-provided directory.
Files are not re-downloaded if they already exist.
Usage:
./local/download_lm.py --out-dir ./download/lm
"""
import argparse
import gzip
import logging
import os
import shutil
from pathlib import Path
@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
from tqdm.auto import tqdm
def download_lm():
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--out-dir", type=str, help="Output directory.")
args = parser.parse_args()
return args
def main(out_dir: str):
url = "http://www.openslr.org/resources/11"
target_dir = Path("data/lm")
out_dir = Path(out_dir)
files_to_download = (
"3-gram.pruned.1e-7.arpa.gz",
@ -26,7 +49,7 @@ def download_lm():
)
for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
filename = target_dir / f
filename = out_dir / f
if filename.is_file() is False:
urlretrieve_progress(
f"{url}/{f}",
@ -34,17 +57,26 @@ def download_lm():
desc=f"Downloading {filename}",
)
else:
print(f"{filename} already exists - skipping")
logging.info(f"{filename} already exists - skipping")
if ".gz" in str(filename):
unzip_file = Path(os.path.splitext(filename)[0])
if unzip_file.is_file() is False:
unzipped = Path(os.path.splitext(filename)[0])
if unzipped.is_file() is False:
with gzip.open(filename, "rb") as f_in:
with open(unzip_file, "wb") as f_out:
with open(unzipped, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
else:
print(f"{unzip_file} already exist - skipping")
logging.info(f"{unzipped} already exist - skipping")
if __name__ == "__main__":
download_lm()
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(f"out_dir: {args.out_dir}")
main(out_dir=args.out_dir)

View File

@ -3,7 +3,7 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang/lexicon.txt"
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format.
"""
import math
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
@ -284,7 +282,9 @@ def lexicon_to_fst(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
@ -301,7 +301,7 @@ def lexicon_to_fst(
def main():
out_dir = Path("data/lang")
out_dir = Path("data/lang_phone")
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
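For reference, each line of lexicon.txt maps a word to its token (phone) sequence. The first three entries below are the special ones prepare.sh adds in stage 5 later in this commit; the last is a hypothetical CMUdict-style pronunciation shown only to illustrate the format:

!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
HELLO HH AH0 L OW1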

View File

@ -5,10 +5,10 @@
"""
This script takes as inputs the following two files:
- data/lang/bpe/bpe.model,
- data/lang/bpe/words.txt
- data/lang_bpe/bpe.model,
- data/lang_bpe/words.txt
and generates the following files in the directory data/lang/bpe:
and generates the following files in the directory data/lang_bpe:
- lexicon.txt
- lexicon_disambig.txt
@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
@ -140,7 +142,7 @@ def generate_lexicon(
def main():
lang_dir = Path("data/lang/bpe")
lang_dir = Path("data/lang_bpe")
model_file = lang_dir / "bpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@ -173,7 +175,9 @@ def main():
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon, token2id=token_sym_table, word2id=word_sym_table,
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
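generate_lexicon() relies on the BPE model to split each word into word pieces; a hedged sketch of that step (the output pieces are illustrative and depend entirely on the trained model):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/lang_bpe/bpe.model")
pieces = sp.encode("HELLO", out_type=str)  # e.g. ['▁HE', 'LLO']; actual pieces depend on the model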

View File

@ -14,18 +14,17 @@ and generates "data/lang/bpe/bpe.model".
#
# Please install a version >=0.1.96
import shutil
from pathlib import Path
import sentencepiece as spm
import shutil
def main():
model_type = "unigram"
vocab_size = 5000
model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
train_text = "data/lang/bpe/train.txt"
model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
train_text = "data/lang_bpe/train.txt"
character_coverage = 1.0
input_sentence_size = 100000000
@ -53,7 +52,7 @@ def main():
sp = spm.SentencePieceProcessor(model_file=str(model_file))
vocab_size = sp.vocab_size()
shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
shutil.copyfile(model_file, "data/lang_bpe/bpe.model")
if __name__ == "__main__":

View File

@ -6,8 +6,38 @@ nj=15
stage=-1
stop_stage=100
. local/parse_options.sh || exit 1
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/LibriSpeech
# You can find BOOKS.TXT, test-clean, train-clean-360, etc., inside it.
# You can download them from https://www.openslr.org/12
#
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
#
# - 3-gram.pruned.1e-7.arpa.gz
# - 3-gram.pruned.1e-7.arpa
# - 4-gram.arpa.gz
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
mkdir -p data
log() {
@ -16,10 +46,11 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
mkdir -p data/lm
./local/download_lm.py
./local/download_lm.py --out-dir=$dl_dir/lm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
#
# ln -sfv /path/to/LibriSpeech data/
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
#
# The script checks that if
#
# data/LibriSpeech/test-clean/.completed exists,
#
# it will not re-download it.
#
# The same goes for dev-clean, dev-other, test-other, train-clean-100
# train-clean-360, and train-other-500
mkdir -p data/LibriSpeech
lhotse download librispeech --full data
if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
lhotse download librispeech --full $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan data/
# ln -sfv /path/to/musan $dl_dir/
#
# and create a file data/.musan_completed
# to avoid downloading it again
if [ ! -f data/.musan_completed ]; then
lhotse download musan data
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare librispeech manifest"
# We assume that you have downloaded the librispeech corpus
# to data/LibriSpeech
log "Stage 1: Prepare LibriSpeech manifest"
# We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/LibriSpeech
mkdir -p data/manifests
lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan data/musan data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@ -84,24 +105,25 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
# TODO: add BPE based lang
mkdir -p data/lang
mkdir -p data/lang_phone
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - data/lm/librispeech-lexicon.txt |
sort | uniq > data/lang/lexicon.txt
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > data/lang_phone/lexicon.txt
if [ ! -f data/lang/L_disambig.pt ]; then
if [ ! -f data/lang_phone/L_disambig.pt ]; then
./local/prepare_lang.py
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "State 6: Prepare BPE based lang"
mkdir -p data/lang/bpe
cp data/lang/words.txt data/lang/bpe/
mkdir -p data/lang_bpe
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt data/lang_bpe/
if [ ! -f data/lang/bpe/train.txt ]; then
if [ ! -f data/lang_bpe/train.txt ]; then
log "Generate data for BPE training"
files=$(
find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > data/lang/bpe/train.txt
done > data/lang_bpe/train.txt
fi
python3 ./local/train_bpe_model.py
if [ ! -f data/lang/bpe/L_disambig.pt ]; then
if [ ! -f data/lang_bpe/L_disambig.pt ]; then
./local/prepare_lang_bpe.py
fi
fi
@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
fi

1
egs/librispeech/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"lang_dir": Path("data/lang"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"subsampling_factor": 3,
@ -328,7 +328,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -340,7 +340,7 @@ def main():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.words["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so

View File

@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp"),
"lang_dir": Path("data/lang"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-3,
"feature_dim": 80,
"weight_decay": 5e-4,

View File

@ -1,7 +1,8 @@
import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Tuple
import k2
import torch
@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
continue
if len(a) < 2:
print(f"Found bad line {line} in lexicon file {filename}")
print("Every line is expected to contain at least 2 fields")
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info(
"Every line is expected to contain at least 2 fields"
)
sys.exit(1)
word = a[0]
if word == "<eps>":
print(f"Found bad line {line} in lexicon file {filename}")
print("<eps> should not be a valid word")
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info("<eps> should not be a valid word")
sys.exit(1)
tokens = a[1:]
@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
class Lexicon(object):
"""Phone based lexicon.
TODO: Add BpeLexicon for BPE models.
"""
"""Phone based lexicon."""
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Args:
@ -121,7 +127,9 @@ class Lexicon(object):
class BpeLexicon(Lexicon):
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
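A short usage sketch for the helpers above, assuming the module is importable as icefall.lexicon and using the lang dir introduced by this commit (the output file name is hypothetical):

from icefall.lexicon import read_lexicon, write_lexicon

# read_lexicon() returns (word, tokens) pairs, e.g. [('!SIL', ['SIL']), ...]
lexicon = read_lexicon("data/lang_phone/lexicon.txt")
write_lexicon("data/lang_phone/lexicon_copy.txt", lexicon)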