mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Add data prep for slu recipe
This commit is contained in:
parent
bbb03f7962
commit
2b579d26f6
14
egs/slu/README.md
Executable file
14
egs/slu/README.md
Executable file
@ -0,0 +1,14 @@
|
||||
## Yesno recipe
|
||||
|
||||
This is the simplest ASR recipe in `icefall`.
|
||||
|
||||
It can be run on CPU and takes less than 30 seconds to
|
||||
get the following WER:
|
||||
|
||||
```
|
||||
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
```
|
||||
|
||||
Please refer to
|
||||
<https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/yesno/index.html>
|
||||
for detailed instructions.
|
136
egs/slu/local/compile_hlg.py
Executable file
136
egs/slu/local/compile_hlg.py
Executable file
@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates HLG from
|
||||
|
||||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G.fst.txt
|
||||
|
||||
The generated HLG is saved in $lang_dir/HLG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
logging.info("Loading G.fst.txt")
|
||||
with open(lang_dir / "G.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
logging.info("Composing H and LG")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
logging.info(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lang_dir)
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
91
egs/slu/local/compute_fbank_slu.py
Executable file
91
egs/slu/local/compute_fbank_slu.py
Executable file
@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file computes fbank features of the Fluent Speech Commands dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or it wastes a
|
||||
# lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_slu():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
# This dataset is rather small, so we use only one job
|
||||
num_jobs = min(1, os.cpu_count())
|
||||
num_mel_bins = 23
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
"valid",
|
||||
"test",
|
||||
)
|
||||
prefix = "slu"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if cuts_file.is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 1, # use one job
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(cuts_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_slu()
|
47
egs/slu/local/generate_lexicon.py
Executable file
47
egs/slu/local/generate_lexicon.py
Executable file
@ -0,0 +1,47 @@
|
||||
import pandas, argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
def generate_lexicon(corpus_dir, lm_dir):
|
||||
data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
|
||||
vocab_transcript = set()
|
||||
vocab_frames = set()
|
||||
transcripts = data['transcription'].tolist()
|
||||
frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
|
||||
|
||||
for transcript in tqdm(transcripts):
|
||||
for word in transcript.split():
|
||||
vocab_transcript.add(word)
|
||||
|
||||
for frame in tqdm(frames):
|
||||
for word in frame:
|
||||
vocab_frames.add('_'.join(word.split()))
|
||||
|
||||
with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
|
||||
lexicon_transcript_file.write("<UNK> 1" + '\n')
|
||||
lexicon_transcript_file.write("<s> 2" + '\n')
|
||||
lexicon_transcript_file.write("</s> 0" + '\n')
|
||||
id = 3
|
||||
for vocab in vocab_transcript:
|
||||
lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
|
||||
id += 1
|
||||
|
||||
with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
|
||||
lexicon_frames_file.write("<UNK> 1" + '\n')
|
||||
lexicon_frames_file.write("<s> 2" + '\n')
|
||||
lexicon_frames_file.write("</s> 0" + '\n')
|
||||
id = 3
|
||||
for vocab in vocab_frames:
|
||||
lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
|
||||
id += 1
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('corpus_dir')
|
||||
parser.add_argument('lm_dir')
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_lexicon(args.corpus_dir, args.lm_dir)
|
||||
|
||||
main()
|
364
egs/slu/local/prepare_lang.py
Executable file
364
egs/slu/local/prepare_lang.py
Executable file
@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
|
||||
consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
|
||||
|
||||
2. Generate tokens.txt, the token table mapping a token to a unique integer.
|
||||
|
||||
3. Generate words.txt, the word table mapping a word to a unique integer.
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import read_lexicon, write_lexicon
|
||||
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
"""Write a symbol to ID mapping to a file.
|
||||
|
||||
Note:
|
||||
No need to implement `read_mapping` as it can be done
|
||||
through :func:`k2.SymbolTable.from_file`.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename to save the mapping.
|
||||
sym2id:
|
||||
A dict mapping symbols to IDs.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
for sym, i in sym2id.items():
|
||||
f.write(f"{sym} {i}\n")
|
||||
|
||||
|
||||
def get_tokens(lexicon: Lexicon) -> List[str]:
|
||||
"""Get tokens from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique tokens.
|
||||
"""
|
||||
ans = set()
|
||||
for _, tokens in lexicon:
|
||||
ans.update(tokens)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def get_words(lexicon: Lexicon) -> List[str]:
|
||||
"""Get words from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique words.
|
||||
"""
|
||||
ans = set()
|
||||
for word, _ in lexicon:
|
||||
ans.add(word)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
|
||||
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
|
||||
at the ends of tokens to ensure that all pronunciations are different,
|
||||
and that none is a prefix of another.
|
||||
|
||||
See also add_lex_disambig.pl from kaldi.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is returned by :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
|
||||
- The output lexicon with disambiguation symbols
|
||||
- The ID of the max disambiguation symbol that appears
|
||||
in the lexicon
|
||||
"""
|
||||
|
||||
# (1) Work out the count of each token-sequence in the
|
||||
# lexicon.
|
||||
count = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
count[" ".join(tokens)] += 1
|
||||
|
||||
# (2) For each left sub-sequence of each token-sequence, note down
|
||||
# that it exists (for identifying prefixes of longer strings).
|
||||
issubseq = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
tokens = tokens.copy()
|
||||
tokens.pop()
|
||||
while tokens:
|
||||
issubseq[" ".join(tokens)] = 1
|
||||
tokens.pop()
|
||||
|
||||
# (3) For each entry in the lexicon:
|
||||
# if the token sequence is unique and is not a
|
||||
# prefix of another word, no disambig symbol.
|
||||
# Else output #1, or #2, #3, ... if the same token-seq
|
||||
# has already been assigned a disambig symbol.
|
||||
ans = []
|
||||
|
||||
# We start with #1 since #0 has its own purpose
|
||||
first_allowed_disambig = 1
|
||||
max_disambig = first_allowed_disambig - 1
|
||||
last_used_disambig_symbol_of = defaultdict(int)
|
||||
|
||||
for word, tokens in lexicon:
|
||||
tokenseq = " ".join(tokens)
|
||||
assert tokenseq != ""
|
||||
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
|
||||
ans.append((word, tokens))
|
||||
continue
|
||||
|
||||
cur_disambig = last_used_disambig_symbol_of[tokenseq]
|
||||
if cur_disambig == 0:
|
||||
cur_disambig = first_allowed_disambig
|
||||
else:
|
||||
cur_disambig += 1
|
||||
|
||||
if cur_disambig > max_disambig:
|
||||
max_disambig = cur_disambig
|
||||
last_used_disambig_symbol_of[tokenseq] = cur_disambig
|
||||
tokenseq += f" #{cur_disambig}"
|
||||
ans.append((word, tokenseq.split()))
|
||||
return ans, max_disambig
|
||||
|
||||
|
||||
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
|
||||
"""Generate ID maps, i.e., map a symbol to a unique ID.
|
||||
|
||||
Args:
|
||||
symbols:
|
||||
A list of unique symbols.
|
||||
Returns:
|
||||
A dict containing the mapping between symbols and IDs.
|
||||
"""
|
||||
return {sym: i for i, sym in enumerate(symbols)}
|
||||
|
||||
|
||||
def add_self_loops(
|
||||
arcs: List[List[Any]], disambig_token: int, disambig_word: int
|
||||
) -> List[List[Any]]:
|
||||
"""Adds self-loops to states of an FST to propagate disambiguation symbols
|
||||
through it. They are added on each state with non-epsilon output symbols
|
||||
on at least one arc out of the state.
|
||||
|
||||
See also fstaddselfloops.pl from Kaldi. One difference is that
|
||||
Kaldi uses OpenFst style FSTs and it has multiple final states.
|
||||
This function uses k2 style FSTs and it does not need to add self-loops
|
||||
to the final state.
|
||||
|
||||
The input label of a self-loop is `disambig_token`, while the output
|
||||
label is `disambig_word`.
|
||||
|
||||
Args:
|
||||
arcs:
|
||||
A list-of-list. The sublist contains
|
||||
`[src_state, dest_state, label, aux_label, score]`
|
||||
disambig_token:
|
||||
It is the token ID of the symbol `#0`.
|
||||
disambig_word:
|
||||
It is the word ID of the symbol `#0`.
|
||||
|
||||
Return:
|
||||
Return new `arcs` containing self-loops.
|
||||
"""
|
||||
states_needs_self_loops = set()
|
||||
for arc in arcs:
|
||||
src, dst, ilabel, olabel, score = arc
|
||||
if olabel != 0:
|
||||
states_needs_self_loops.add(src)
|
||||
|
||||
ans = []
|
||||
for s in states_needs_self_loops:
|
||||
ans.append([s, s, disambig_token, disambig_word, 0])
|
||||
|
||||
return arcs + ans
|
||||
|
||||
|
||||
def lexicon_to_fst(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
sil_token: str = "!SIL",
|
||||
sil_prob: float = 0.5,
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format) with optional silence at
|
||||
the beginning and end of each word.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
sil_token:
|
||||
The silence token.
|
||||
sil_prob:
|
||||
The probability for adding a silence at the beginning and end
|
||||
of the word.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
assert sil_prob > 0.0 and sil_prob < 1.0
|
||||
# CAUTION: we use score, i.e, negative cost.
|
||||
sil_score = math.log(sil_prob)
|
||||
no_sil_score = math.log(1.0 - sil_prob)
|
||||
|
||||
start_state = 0
|
||||
loop_state = 1 # words enter and leave from here
|
||||
sil_state = 2 # words terminate here when followed by silence; this state
|
||||
# has a silence transition to loop_state.
|
||||
next_state = 3 # the next un-allocated state, will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
# assert token2id["<eps>"] == 0
|
||||
# assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
sil_token = word2id[sil_token]
|
||||
|
||||
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
|
||||
arcs.append([start_state, sil_state, eps, eps, sil_score])
|
||||
arcs.append([sil_state, loop_state, sil_token, eps, 0])
|
||||
|
||||
for word, tokens in lexicon:
|
||||
assert len(tokens) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
tokens = [word2id[i] for i in tokens]
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, tokens[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last token of this word
|
||||
# It has two out-going arcs, one to the loop state,
|
||||
# the other one to the sil_state.
|
||||
i = len(tokens) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
|
||||
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = word2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path("data/lm")
|
||||
lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
|
||||
names = ["frames", "transcript"]
|
||||
sil_token = "!SIL"
|
||||
sil_prob = 0.5
|
||||
|
||||
for name, lexicon_filename in zip(names, lexicon_filenames):
|
||||
lexicon = read_lexicon(lexicon_filename)
|
||||
tokens = get_words(lexicon)
|
||||
words = get_words(lexicon)
|
||||
new_lexicon = []
|
||||
for lexicon_item in lexicon:
|
||||
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
|
||||
lexicon = new_lexicon
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in tokens
|
||||
tokens.append(f"#{i}")
|
||||
|
||||
tokens = ["<eps>"] + tokens
|
||||
words = ['eps'] + words + ["#0", "!SIL"]
|
||||
|
||||
token2id = generate_id_map(tokens)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
|
||||
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
|
||||
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst(
|
||||
lexicon,
|
||||
token2id=word2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst(
|
||||
lexicon_disambig,
|
||||
token2id=word2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
|
||||
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
|
||||
|
||||
if False:
|
||||
# Just for debugging, will remove it
|
||||
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
|
||||
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
|
||||
L_disambig.labels_sym = L.labels_sym
|
||||
L_disambig.aux_labels_sym = L.aux_labels_sym
|
||||
L.draw(out_dir / "L.png", title="L")
|
||||
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
|
||||
|
||||
|
||||
main()
|
100
egs/slu/prepare.sh
Executable file
100
egs/slu/prepare.sh
Executable file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=3
|
||||
stop_stage=5
|
||||
|
||||
data_dir=/home/xli257/slu/fluent_speech_commands_dataset
|
||||
|
||||
lang_dir=data/lang_phone
|
||||
lm_dir=data/lm
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
mkdir -p $lang_dir
|
||||
mkdir -p $lm_dir
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "data_dir: $data_dir"
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare slu manifest"
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare slu $data_dir data/manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Compute fbank for yesno"
|
||||
mkdir -p data/fbank
|
||||
python ./local/compute_fbank_slu.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare lang"
|
||||
# NOTE: "<UNK> SIL" is added for implementation convenience
|
||||
# as the graph compiler code requires that there is a OOV word
|
||||
# in the lexicon.
|
||||
python ./local/generate_lexicon.py $data_dir $lm_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Train LM"
|
||||
# We use a unigram G
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order 1 \
|
||||
-text $lm_dir/words_transcript.txt \
|
||||
-lm $lm_dir/G_transcript.arpa
|
||||
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order 1 \
|
||||
-text $lm_dir/words_frames.txt \
|
||||
-lm $lm_dir/G_frames.arpa
|
||||
|
||||
python ./local/prepare_lang.py
|
||||
|
||||
if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
|
||||
python -m kaldilm \
|
||||
--read-symbol-table="$lm_dir/words_transcript.txt" \
|
||||
$lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lm_dir/G_frames.fst.txt ]; then
|
||||
python -m kaldilm \
|
||||
--read-symbol-table="$lm_dir/words_frames.txt" \
|
||||
$lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
|
||||
fi
|
||||
|
||||
mkdir -p $lm_dir/frames
|
||||
mkdir -p $lm_dir/transcript
|
||||
|
||||
chmod -R +777 .
|
||||
|
||||
for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
|
||||
do
|
||||
j=${i//"_frames"/}
|
||||
mv "$lm_dir/$i" $lm_dir/frames/$j
|
||||
done
|
||||
|
||||
for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
|
||||
do
|
||||
j=${i//"_transcript"/}
|
||||
mv "$lm_dir/$i" $lm_dir/transcript/$j
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compile HLG"
|
||||
./local/compile_hlg.py --lang-dir $lm_dir/frames
|
||||
./local/compile_hlg.py --lang-dir $lm_dir/transcript
|
||||
|
||||
fi
|
1
egs/slu/shared
Symbolic link
1
egs/slu/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../icefall/shared/
|
@ -33,7 +33,7 @@ parser.add_argument(
|
||||
"-ngram-order",
|
||||
type=int,
|
||||
default=4,
|
||||
choices=[2, 3, 4, 5, 6, 7],
|
||||
choices=[1, 2, 3, 4, 5, 6, 7],
|
||||
help="Order of n-gram",
|
||||
)
|
||||
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
|
||||
@ -105,7 +105,7 @@ class NgramCounts:
|
||||
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
|
||||
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
|
||||
def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
|
||||
assert ngram_order >= 2
|
||||
# assert ngram_order >= 2
|
||||
|
||||
self.ngram_order = ngram_order
|
||||
self.bos_symbol = bos_symbol
|
||||
@ -169,7 +169,7 @@ class NgramCounts:
|
||||
with open(filename, encoding=default_encoding) as fp:
|
||||
for line in fp:
|
||||
line = line.strip(strip_chars)
|
||||
self.add_raw_counts_from_line(line)
|
||||
self.add_raw_counts_from_line(line.split()[0])
|
||||
lines_processed += 1
|
||||
if lines_processed == 0 or args.verbose > 0:
|
||||
print(
|
||||
|
Loading…
x
Reference in New Issue
Block a user