Add data prep for slu recipe

This commit is contained in:
Xinyuan Li 2023-09-23 13:36:52 -04:00
parent bbb03f7962
commit 2b579d26f6
8 changed files with 756 additions and 3 deletions

14
egs/slu/README.md Executable file

@@ -0,0 +1,14 @@
## SLU recipe
This recipe builds a spoken language understanding (SLU) system on the
Fluent Speech Commands dataset. In addition to its transcript, each
utterance is labeled with a frame consisting of an action, an object,
and a location. `prepare.sh` prepares the manifests, fbank features,
lexicons, unigram LMs, and HLG decoding graphs for both the transcripts
and the frames.
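A minimal invocation of the data preparation looks like the following
(the dataset path defaults to a hard-coded location in `prepare.sh`;
adjust `data_dir` there first):

```bash
cd egs/slu
./prepare.sh
```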

136
egs/slu/local/compile_hlg.py Executable file

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: Path) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lm/frames or data/lm/transcript.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
logging.info("Loading G.fst.txt")
with open(lang_dir / "G.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
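For reference, the saved graph can be loaded back the same way the script saves it (a minimal sketch; the `data/lm/transcript` path assumes the directory layout produced by `prepare.sh` below):

```python
import k2
import torch

# Load the decoding graph saved by compile_hlg.py.
d = torch.load("data/lm/transcript/HLG.pt")
HLG = k2.Fsa.from_dict(d)
print(HLG.shape)  # (num_states, None) for a single FSA
```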

91
egs/slu/local/compute_fbank_slu.py Executable file

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the Fluent Speech Commands dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a
# lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_slu():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
# This dataset is rather small, so we use only one job
num_jobs = min(1, os.cpu_count())
num_mel_bins = 23
dataset_parts = (
"train",
"valid",
"test",
)
prefix = "slu"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
if cuts_file.is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # this dataset is small, so we use a single job even
                # when an executor is available
                num_jobs=num_jobs if ex is None else 1,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(cuts_file)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_slu()
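Once this script has run, the generated cuts can be inspected with lhotse (a sketch; the filename follows the `{prefix}_cuts_{partition}.{suffix}` pattern used above):

```python
from lhotse import CutSet

# Load the manifest written by compute_fbank_slu().
cuts = CutSet.from_file("data/fbank/slu_cuts_train.jsonl.gz")
print(len(cuts))  # 3x the original size, due to speed perturbation
cut = next(iter(cuts))
print(cut.load_features().shape)  # (num_frames, 23) given the 23-bin fbank config
```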

47
egs/slu/local/generate_lexicon.py

@@ -0,0 +1,47 @@
"""
Generate word lists for the Fluent Speech Commands dataset: one over the
transcripts and one over the frames (action, object, location), written
to lm_dir/words_transcript.txt and lm_dir/words_frames.txt.
"""
import argparse

import pandas
from tqdm import tqdm
def generate_lexicon(corpus_dir, lm_dir):
    data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col=0, header=0)
vocab_transcript = set()
vocab_frames = set()
transcripts = data['transcription'].tolist()
frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
for transcript in tqdm(transcripts):
for word in transcript.split():
vocab_transcript.add(word)
for frame in tqdm(frames):
for word in frame:
vocab_frames.add('_'.join(word.split()))
with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + '\n')
lexicon_transcript_file.write("<s> 2" + '\n')
lexicon_transcript_file.write("</s> 0" + '\n')
        index = 3
        for vocab in vocab_transcript:
            lexicon_transcript_file.write(vocab + ' ' + str(index) + '\n')
            index += 1
with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + '\n')
lexicon_frames_file.write("<s> 2" + '\n')
lexicon_frames_file.write("</s> 0" + '\n')
        index = 3
        for vocab in vocab_frames:
            lexicon_frames_file.write(vocab + ' ' + str(index) + '\n')
            index += 1
parser = argparse.ArgumentParser()
parser.add_argument('corpus_dir')
parser.add_argument('lm_dir')
def main():
    args = parser.parse_args()
    generate_lexicon(args.corpus_dir, args.lm_dir)


if __name__ == '__main__':
    main()
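The generated files double as symbol tables and as LM training text later in `prepare.sh` (which is why `make_kn_lm.py` is patched below to count only the first field of each line). Illustrative contents, with hypothetical word entries after the three fixed symbols:

```
<UNK> 1
<s> 2
</s> 0
change_language 3
activate 4
```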

364
egs/slu/local/prepare_lang.py Executable file

@@ -0,0 +1,364 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "!SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
# assert token2id["<eps>"] == 0
# assert word2id["<eps>"] == 0
eps = 0
    sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
        tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def main():
out_dir = Path("data/lm")
lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
names = ["frames", "transcript"]
sil_token = "!SIL"
sil_prob = 0.5
for name, lexicon_filename in zip(names, lexicon_filenames):
lexicon = read_lexicon(lexicon_filename)
        # Each word maps to itself as its only token in this recipe,
        # so the token inventory is the word list itself.
        tokens = get_words(lexicon)
words = get_words(lexicon)
new_lexicon = []
for lexicon_item in lexicon:
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
lexicon = new_lexicon
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
        for i in range(max_disambig + 1):
            disambig = f"#{i}"
            assert disambig not in tokens
            tokens.append(disambig)
tokens = ["<eps>"] + tokens
        words = ["<eps>"] + words + ["#0", "!SIL"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
        # Tokens and words share one symbol table in this recipe, so we
        # pass word2id for both mappings.
        L = lexicon_to_fst(
            lexicon,
            token2id=word2id,
            word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
    main()
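As a quick sanity check of `add_disambig_symbols`, consider a toy lexicon in which two words share a token sequence that is also a prefix of a third (a sketch assuming the functions above are in scope):

```python
lexicon = [("a", ["x"]), ("b", ["x"]), ("c", ["x", "y"])]
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
# "a" and "b" share the sequence ["x"], which is also a prefix of "c"'s
# ["x", "y"], so "a" and "b" receive #1 and #2 while "c" needs nothing:
print(lexicon_disambig)  # [('a', ['x', '#1']), ('b', ['x', '#2']), ('c', ['x', 'y'])]
print(max_disambig)  # 2
```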

100
egs/slu/prepare.sh Executable file

@@ -0,0 +1,100 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=1
stop_stage=5
# Adjust this to point at your local copy of the Fluent Speech Commands dataset.
data_dir=/home/xli257/slu/fluent_speech_commands_dataset
lang_dir=data/lang_phone
lm_dir=data/lm
. shared/parse_options.sh || exit 1
mkdir -p $lang_dir
mkdir -p $lm_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "data_dir: $data_dir"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare slu manifest"
mkdir -p data/manifests
lhotse prepare slu $data_dir data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute fbank for yesno"
mkdir -p data/fbank
python ./local/compute_fbank_slu.py
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare lang"
# NOTE: "<UNK> SIL" is added for implementation convenience
# as the graph compiler code requires that there is a OOV word
# in the lexicon.
python ./local/generate_lexicon.py $data_dir $lm_dir
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Train LM"
# We use a unigram G
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_transcript.txt \
-lm $lm_dir/G_transcript.arpa
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_frames.txt \
-lm $lm_dir/G_frames.arpa
python ./local/prepare_lang.py
if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_transcript.txt" \
$lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
fi
if [ ! -f $lm_dir/G_frames.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_frames.txt" \
$lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
fi
mkdir -p $lm_dir/frames
mkdir -p $lm_dir/transcript
  chmod -R 777 .
for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
do
j=${i//"_frames"/}
mv "$lm_dir/$i" $lm_dir/frames/$j
done
for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
do
j=${i//"_transcript"/}
mv "$lm_dir/$i" $lm_dir/transcript/$j
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compile HLG"
./local/compile_hlg.py --lang-dir $lm_dir/frames
./local/compile_hlg.py --lang-dir $lm_dir/transcript
fi
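Because the script sources `shared/parse_options.sh`, the variables defined at the top can be overridden from the command line in the usual Kaldi style (a sketch; substitute your own dataset path):

```bash
./prepare.sh --stage 1 --stop-stage 5 \
  --data-dir /path/to/fluent_speech_commands_dataset
```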

1
egs/slu/shared Symbolic link

@@ -0,0 +1 @@
../../icefall/shared/

icefall/shared/make_kn_lm.py

@@ -33,7 +33,7 @@ parser.add_argument(
     "-ngram-order",
     type=int,
     default=4,
-    choices=[2, 3, 4, 5, 6, 7],
+    choices=[1, 2, 3, 4, 5, 6, 7],
     help="Order of n-gram",
 )
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
@@ -105,7 +105,7 @@ class NgramCounts:
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
     def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
-        assert ngram_order >= 2
+        # assert ngram_order >= 2
         self.ngram_order = ngram_order
         self.bos_symbol = bos_symbol
@@ -169,7 +169,7 @@ class NgramCounts:
         with open(filename, encoding=default_encoding) as fp:
             for line in fp:
                 line = line.strip(strip_chars)
-                self.add_raw_counts_from_line(line)
+                self.add_raw_counts_from_line(line.split()[0])
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
             print(
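The last hunk exists because `prepare.sh` trains the LM directly on the `words_*.txt` files, whose lines are `word id` pairs rather than sentences; only the first field should contribute counts. A minimal illustration of what the patched line extracts:

```python
# Each line of words_transcript.txt / words_frames.txt is "word id".
for line in ["<UNK> 1", "<s> 2", "</s> 0", "change_language 3"]:
    print(line.split()[0])  # <UNK>, <s>, </s>, change_language
```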