mirror of https://github.com/k2-fsa/icefall.git
WIP: Refactoring

parent c72a11ea1f
commit 1fa30998da
.gitignore (vendored)
@@ -4,3 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
+download/

@@ -62,7 +62,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,

@@ -367,15 +367,13 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

-    # HLG = k2.ctc_topo(4999).to(device)
-
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",

@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 0.0,
             "subsampling_factor": 4,

@@ -188,7 +188,7 @@ class Transformer(nn.Module):
         encoder_mask: Tensor,
         supervision: Supervisions = None,
         graph_compiler: object = None,
-        token_ids: List[int] = None,
+        token_ids: List[List[int]] = None,
         sos_id: Optional[int] = None,
         eos_id: Optional[int] = None,
     ) -> Tensor:

@@ -199,6 +199,7 @@ class Transformer(nn.Module):
             supervision: Supervison in lhotse format, get from batch['supervisions']
             graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones)
               , graph_compiler.words and graph_compiler.oov
+            token_ids: A list of lists. Each list contains word piece IDs for an utterance.
             sos_id: sos token id
             eos_id: eos token id

@@ -210,7 +211,10 @@ class Transformer(nn.Module):
                 supervision, graph_compiler.lexicon.words, graph_compiler.oov
             )
             ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text, graph_compiler.L_inv, sos_id, eos_id,
+                batch_text,
+                graph_compiler.L_inv,
+                sos_id,
+                eos_id,
             )
         elif token_ids is not None:
             _sos = torch.tensor([sos_id])

@@ -225,7 +229,7 @@ class Transformer(nn.Module):
             ys_out_pad = pad_list(ys_out, -1)

         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")

         ys_in_pad = ys_in_pad.to(x.device)
         ys_out_pad = ys_out_pad.to(x.device)

@@ -284,7 +288,7 @@ class Transformer(nn.Module):
             ys_in_pad = pad_list(ys_in, eos_id)
             ys_out_pad = pad_list(ys_out, -1)
         else:
-            raise ValueError("Invalid input for decoder self attetion")
+            raise ValueError("Invalid input for decoder self attention")

         ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64)

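For reference, a minimal sketch (not part of this commit) of how decoder inputs and targets are built once token_ids is a List[List[int]], mirroring the _sos/_eos and pad_list calls in the hunks above. The pad_list helper here is a stand-in with assumed behavior (right-padding 1-D tensors to a common length), and the token IDs are invented for illustration.

    import torch

    def pad_list(ys, pad_value):
        # Assumed behavior of the pad_list helper used in the model code:
        # right-pad a list of 1-D tensors to the same length with pad_value.
        n = max(y.size(0) for y in ys)
        out = torch.full((len(ys), n), pad_value, dtype=ys[0].dtype)
        for i, y in enumerate(ys):
            out[i, : y.size(0)] = y
        return out

    token_ids = [[5, 7, 9], [3, 4]]  # word piece IDs for two utterances (made up)
    sos_id, eos_id = 1, 2

    _sos = torch.tensor([sos_id])
    _eos = torch.tensor([eos_id])
    ys = [torch.tensor(u) for u in token_ids]
    ys_in = [torch.cat([_sos, y]) for y in ys]   # decoder input: <sos> + tokens
    ys_out = [torch.cat([y, _eos]) for y in ys]  # decoder target: tokens + <eos>

    ys_in_pad = pad_list(ys_in, eos_id)   # padded with eos_id, masked downstream
    ys_out_pad = pad_list(ys_out, -1)     # -1 marks positions ignored by the loss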
@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
-        The language directory, e.g., data/lang or data/lang/bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe.

    Return:
      An FSA representing HLG.

@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     return HLG


-def phone_based_HLG():
-    if Path("data/lm/HLG.pt").is_file():
-        return
-
-    logging.info("Compiling phone based HLG")
-    HLG = compile_HLG("data/lang")
-
-    logging.info("Saving HLG.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG.pt")
-
-
-def bpe_based_HLG():
-    if Path("data/lm/HLG_bpe.pt").is_file():
-        return
-
-    logging.info("Compiling BPE based HLG")
-    HLG = compile_HLG("data/lang/bpe")
-    logging.info("Saving HLG_bpe.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
-
-
 def main():
-    phone_based_HLG()
-    bpe_based_HLG()
+    for d in ["data/lang_phone", "data/lang_bpe"]:
+        d = Path(d)
+        logging.info(f"Processing {d}")
+
+        if (d / "HLG.pt").is_file():
+            logging.info(f"{d}/HLG.pt already exists - skipping")
+            continue
+
+        HLG = compile_HLG(d)
+        logging.info(f"Saving HLG.pt to {d}")
+        torch.save(HLG.as_dict(), f"{d}/HLG.pt")


 if __name__ == "__main__":

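For context, a minimal sketch (not part of this commit) of how an HLG.pt produced by this script is consumed later, following the loading code in the decoding hunks above; the lang dir and device are example values.

    import k2
    import torch

    device = torch.device("cpu")  # example; the decoding script picks the device itself

    # Load the HLG compiled and saved by main() above (path is an assumption).
    HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe/HLG.pt"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False

    # LM rescoring expects lm_scores on HLG, as the decoding script does above.
    if not hasattr(HLG, "lm_scores"):
        HLG.lm_scores = HLG.scores.clone()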
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3

 """
-This file computes fbank features of the librispeech dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+This file computes fbank features of the LibriSpeech dataset.
+Its looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

 import logging
 import os
 from pathlib import Path

@@ -40,9 +42,9 @@ def compute_fbank_librispeech():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
             if (output_dir / f"cuts_{partition}.json.gz").is_file():
-                print(f"{partition} already exists - skipping.")
+                logging.info(f"{partition} already exists - skipping.")
                 continue
-            print("Processing", partition)
+            logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
                 recordings=m["recordings"],
                 supervisions=m["supervisions"],

@@ -65,4 +67,10 @@ def compute_fbank_librispeech():


 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     compute_fbank_librispeech()

@@ -2,10 +2,12 @@

 """
 This file computes fbank features of the musan dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+Its looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

 import logging
 import os
 from pathlib import Path

@@ -34,10 +36,10 @@ def compute_fbank_musan():
     musan_cuts_path = output_dir / "cuts_musan.json.gz"

     if musan_cuts_path.is_file():
-        print(f"{musan_cuts_path} already exists - skipping")
+        logging.info(f"{musan_cuts_path} already exists - skipping")
         return

-    print("Extracting features for Musan")
+    logging.info("Extracting features for Musan")

     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -63,4 +65,9 @@ def compute_fbank_musan():


 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()

@@ -2,10 +2,25 @@

 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
 """
-This file downloads librispeech LM files to data/lm
+This file downloads the following LibriSpeech LM files:
+
+  - 3-gram.pruned.1e-7.arpa.gz
+  - 4-gram.arpa.gz
+  - librispeech-vocab.txt
+  - librispeech-lexicon.txt
+
+from http://www.openslr.org/resources/11
+and save them in the user provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+  ./local/download_lm.py --out-dir ./download/lm
 """

+import argparse
 import gzip
+import logging
 import os
 import shutil
 from pathlib import Path

@@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
 from tqdm.auto import tqdm


-def download_lm():
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--out-dir", type=str, help="Output directory.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main(out_dir: str):
     url = "http://www.openslr.org/resources/11"
-    target_dir = Path("data/lm")
+    out_dir = Path(out_dir)

     files_to_download = (
         "3-gram.pruned.1e-7.arpa.gz",

@@ -26,7 +49,7 @@ def download_lm():
     )

     for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
-        filename = target_dir / f
+        filename = out_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
                 f"{url}/{f}",

@@ -34,17 +57,26 @@ def download_lm():
                 desc=f"Downloading {filename}",
             )
         else:
-            print(f"{filename} already exists - skipping")
+            logging.info(f"{filename} already exists - skipping")

         if ".gz" in str(filename):
-            unzip_file = Path(os.path.splitext(filename)[0])
-            if unzip_file.is_file() is False:
+            unzipped = Path(os.path.splitext(filename)[0])
+            if unzipped.is_file() is False:
                 with gzip.open(filename, "rb") as f_in:
-                    with open(unzip_file, "wb") as f_out:
+                    with open(unzipped, "wb") as f_out:
                         shutil.copyfileobj(f_in, f_out)
             else:
-                print(f"{unzip_file} already exist - skipping")
+                logging.info(f"{unzipped} already exist - skipping")


 if __name__ == "__main__":
-    download_lm()
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+
+    main(out_dir=args.out_dir)

@@ -3,7 +3,7 @@
 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as input a lexicon file "data/lang/lexicon.txt"
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
 consisting of words and tokens (i.e., phones) and does the following:

 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

@@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
 import math
-import re
-import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple

@@ -284,7 +282,9 @@ def lexicon_to_fst(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state

@@ -301,7 +301,7 @@ def lexicon_to_fst(


 def main():
-    out_dir = Path("data/lang")
+    out_dir = Path("data/lang_phone")
     lexicon_filename = out_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5

@@ -5,10 +5,10 @@
 """
 This script takes as inputs the following two files:

-    - data/lang/bpe/bpe.model,
-    - data/lang/bpe/words.txt
+    - data/lang_bpe/bpe.model,
+    - data/lang_bpe/words.txt

-and generates the following files in the directory data/lang/bpe:
+and generates the following files in the directory data/lang_bpe:

     - lexicon.txt
     - lexicon_disambig.txt

@@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state

@@ -140,7 +142,7 @@ def generate_lexicon(


 def main():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = Path("data/lang_bpe")
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

@@ -173,7 +175,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )

     L_disambig = lexicon_to_fst_no_sil(

@@ -14,18 +14,17 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96

+import shutil
 from pathlib import Path

 import sentencepiece as spm

-import shutil
-

 def main():
     model_type = "unigram"
     vocab_size = 5000
-    model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
-    train_text = "data/lang/bpe/train.txt"
+    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
+    train_text = "data/lang_bpe/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -53,7 +52,7 @@ def main():
     sp = spm.SentencePieceProcessor(model_file=str(model_file))
     vocab_size = sp.vocab_size()

-    shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
+    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")


 if __name__ == "__main__":

@@ -6,8 +6,38 @@ nj=15
 stage=-1
 stop_stage=100

-. local/parse_options.sh || exit 1
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/LibriSpeech
+#      You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#      You can download them from https://www.openslr.org/12
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#      http://www.openslr.org/resources/11
+#
+#        - 3-gram.pruned.1e-7.arpa.gz
+#        - 3-gram.pruned.1e-7.arpa
+#        - 4-gram.arpa.gz
+#        - 4-gram.arpa
+#        - librispeech-vocab.txt
+#        - librispeech-lexicon.txt
+#
+#  - $do_dir/musan
+#      This directory contains the following directories downloaded from
+#      http://www.openslr.org/17/
+#
+#        - music
+#        - noise
+#        - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+

 # All generated files by this script are saved in "data"
 mkdir -p data

 log() {

@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
-  mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
-  #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi

   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare librispeech manifest"
-  # We assume that you have downloaded the librispeech corpus
-  # to data/LibriSpeech
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then

@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then

@@ -84,24 +105,25 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt

-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ ! -f data/lang_phone/L_disambig.pt ]; then
     ./local/prepare_lang.py
   fi
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "State 6: Prepare BPE based lang"
-  mkdir -p data/lang/bpe
-  cp data/lang/words.txt data/lang/bpe/
+  mkdir -p data/lang_bpe
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp data/lang_phone/words.txt data/lang_bpe/

-  if [ ! -f data/lang/bpe/train.txt ]; then
+  if [ ! -f data/lang_bpe/train.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"

@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > data/lang/bpe/train.txt
+    done > data/lang_bpe/train.txt
   fi

   python3 ./local/train_bpe_model.py

-  if [ ! -f data/lang/bpe/L_disambig.pt ]; then
+  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py
   fi
 fi

@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   # We assume you have install kaldilm, if not, please install
   # it using: pip install kaldilm

+  mkdir -p data/lm
   if [ ! -f data/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
   fi

   if [ ! -f data/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
   fi
 fi

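For context, a minimal sketch (not part of this commit) of how the G_*.fst.txt files produced by kaldilm are read back into k2, following the G_4_gram loading code in the decoding hunk further below; the path is an example.

    import k2

    with open("data/lm/G_4_gram.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    # G.aux_labels can be dropped when it is not needed, as the decoding script notes.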
egs/librispeech/ASR/shared (symbolic link)
@@ -0,0 +1 @@
+../../../icefall/shared/

@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp/"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "subsampling_factor": 3,

@@ -328,7 +328,7 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
+    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

@@ -340,7 +340,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]

             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
             # G.aux_labels is not needed in later computations, so

@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lr": 1e-3,
             "feature_dim": 80,
             "weight_decay": 5e-4,

@@ -1,7 +1,8 @@
+import logging
 import re
 import sys
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Tuple

 import k2
 import torch

@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue

             if len(a) < 2:
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)
             word = a[0]
             if word == "<eps>":
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("<eps> should not be a valid word")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info("<eps> should not be a valid word")
                 sys.exit(1)

             tokens = a[1:]

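For reference, a minimal sketch (not part of this commit) of the lexicon format that read_lexicon parses: each line is a word followed by its tokens, yielding (word, tokens) pairs as in the List[Tuple[str, List[str]]] return type above. The example entries are invented.

    # Hypothetical lexicon.txt content:
    #   HELLO HH AH L OW
    #   WORLD W ER L D
    example = "HELLO HH AH L OW\nWORLD W ER L D\n"

    lexicon = []
    for line in example.splitlines():
        a = line.split()
        if len(a) < 2:  # read_lexicon treats such lines as bad and exits
            continue
        lexicon.append((a[0], a[1:]))

    assert lexicon[0] == ("HELLO", ["HH", "AH", "L", "OW"])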
@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:


 class Lexicon(object):
-    """Phone based lexicon.
-
-    TODO: Add BpeLexicon for BPE models.
-    """
+    """Phone based lexicon."""

     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Args:

@@ -121,7 +127,9 @@ class Lexicon(object):

 class BpeLexicon(Lexicon):
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.