From 03853f1ee5bc71b7ff5e36289682a430a196f75e Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 31 May 2023 12:46:17 +0800 Subject: [PATCH] Add peoples_speech (#1101) * update * Small fix * Update egs/peoples_speech/ASR/prepare.sh Co-authored-by: Fangjun Kuang * limit normalize log * Update egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py Co-authored-by: Fangjun Kuang * Update compute_fbank_peoples_speech_splits.py * Update compute_fbank_peoples_speech_valid_test.py --------- Co-authored-by: Fangjun Kuang --- .../ASR/local/compute_fbank_musan.py | 1 + .../compute_fbank_peoples_speech_splits.py | 154 +++++++++++ ...compute_fbank_peoples_speech_valid_test.py | 93 +++++++ egs/peoples_speech/ASR/local/filter_cuts.py | 1 + .../ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_peoples_speech.py | 123 +++++++++ .../ASR/local/train_bpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/peoples_speech/ASR/prepare.sh | 247 ++++++++++++++++++ egs/peoples_speech/ASR/shared | 1 + 10 files changed, 623 insertions(+) create mode 120000 egs/peoples_speech/ASR/local/compute_fbank_musan.py create mode 100755 egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py create mode 100755 egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py create mode 120000 egs/peoples_speech/ASR/local/filter_cuts.py create mode 120000 egs/peoples_speech/ASR/local/prepare_lang_bpe.py create mode 100755 egs/peoples_speech/ASR/local/preprocess_peoples_speech.py create mode 120000 egs/peoples_speech/ASR/local/train_bpe_model.py create mode 120000 egs/peoples_speech/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/peoples_speech/ASR/prepare.sh create mode 120000 egs/peoples_speech/ASR/shared diff --git a/egs/peoples_speech/ASR/local/compute_fbank_musan.py b/egs/peoples_speech/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/peoples_speech/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py new file mode 100755 index 000000000..c2ab3d07d --- /dev/null +++ b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + set_audio_duration_mismatch_tolerance, + set_caching_enabled, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the train subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + + return parser.parse_args() + + +def compute_fbank_peoples_speech_splits(args): + subsets = ("dirty", "dirty_sa", "clean", "clean_sa") + num_splits = args.num_splits + output_dir = f"data/fbank/peoples_speech_train_split" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = 8 + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_caching_enabled(False) + + for partition in subsets: + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {partition}: {idx}") + + cuts_path = output_dir / f"peoples_speech_cuts_{partition}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = ( + output_dir / f"peoples_speech_cuts_{partition}_raw.{idx}.jsonl.gz" + ) + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/peoples_speech_feats_{partition}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_peoples_speech_splits(args) + + +if __name__ == "__main__": + main() diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py new file mode 100755 index 000000000..89f43a674 --- /dev/null +++ b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the People's Speech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_peoples_speech_valid_test(): + src_dir = Path(f"data/manifests") + output_dir = Path(f"data/fbank") + num_workers = 42 + batch_duration = 600 + + subsets = ("validation", "test") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = output_dir / f"peoples_speech_cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + + raw_cuts_path = output_dir / f"peoples_speech_cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/peoples_speech_feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_peoples_speech_valid_test() diff --git a/egs/peoples_speech/ASR/local/filter_cuts.py b/egs/peoples_speech/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/peoples_speech/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/prepare_lang_bpe.py b/egs/peoples_speech/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/peoples_speech/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py b/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py new file mode 100755 index 000000000..c5417049f --- /dev/null +++ b/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import re +from pathlib import Path +from typing import Optional + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + return parser.parse_args() + + +def normalize_text(utt: str) -> str: + utt = re.sub(r"[{0}]+".format("-"), " ", utt) + return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + + +def preprocess_peoples_speech(dataset: Optional[str] = None): + src_dir = Path(f"data/manifests") + output_dir = Path(f"data/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "validation", + "test", + "dirty", + "dirty_sa", + "clean", + "clean_sa", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest, it may takes 8 minutes") + prefix = f"peoples_speech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + logging.info(f"Normalizing text in {partition}") + i = 0 + for sup in m["supervisions"]: + text = str(sup.text) + orig_text = text + sup.text = normalize_text(sup.text) + text = str(sup.text) + if i < 10 and len(orig_text) != len(text): + logging.info( + f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" + ) + i += 1 + + # Create long-recording cut manifests. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Run data augmentation that needs to be done in the + # time domain. + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_peoples_speech(dataset=args.dataset) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/peoples_speech/ASR/local/train_bpe_model.py b/egs/peoples_speech/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/peoples_speech/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py b/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/prepare.sh b/egs/peoples_speech/ASR/prepare.sh new file mode 100755 index 000000000..3787858d9 --- /dev/null +++ b/egs/peoples_speech/ASR/prepare.sh @@ -0,0 +1,247 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=32 +stage=-1 +stop_stage=100 + +# Split data/set to a number of pieces +# This is to avoid OOM during feature extraction. +num_per_split=4000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/peoples_speech +# This directory contains the following files downloaded from +# https://huggingface.co/datasets/MLCommons/peoples_speech +# +# - test +# - train +# - validation +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/peoples_speech, + # you can create a symlink + # + # ln -sfv /path/to/peoples_speech $dl_dir/peoples_speech + # + if [ ! -d $dl_dir/peoples_speech/train ]; then + git lfs install + git clone https://huggingface.co/datasets/MLCommons/peoples_speech + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare People's Speech manifest" + # We assume that you have downloaded the People's Speech corpus + # to $dl_dir/peoples_speech + mkdir -p data/manifests + if [ ! -e data/manifests/.peoples_speech.done ]; then + lhotse prepare peoples-speech -j $nj $dl_dir/peoples_speech data/manifests + touch data/manifests/.peoples_speech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess People's Speech manifest" + mkdir -p data/fbank + if [ ! -e data/fbank/.preprocess_complete ]; then + ./local/preprocess_peoples_speech.py + touch data/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for valid and test subsets of People's Speech" + if [ ! -e data/fbank/.peoples_speech_valid_test.done ]; then + ./local/compute_fbank_peoples_speech_valid_test.py + touch data/fbank/.peoples_speech_valid_test.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split train subset into pieces" + split_dir=data/fbank/peoples_speech_train_split + if [ ! -e $split_dir/.peoples_speech_dirty_split.done ]; then + lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz $split_dir $num_per_split + touch $split_dir/.peoples_speech_dirty_split.done + fi + + if [ ! -e $split_dir/.peoples_speech_dirty_sa_split.done ]; then + lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz $split_dir $num_per_split + touch $split_dir/.peoples_speech_dirty_sa_split.done + fi + + if [ ! -e $split_dir/.peoples_speech_clean_split.done ]; then + lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz $split_dir $num_per_split + touch $split_dir/.peoples_speech_clean_split.done + fi + + if [ ! -e $split_dir/.peoples_speech_clean_sa_split.done ]; then + lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz $split_dir $num_per_split + touch $split_dir/.peoples_speech_clean_sa_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compute features for train subset of People's Speech" + if [ ! -e data/fbank/.peoples_speech_train.done ]; then + ./local/compute_fbank_peoples_speech_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits 2000 + touch data/fbank/.peoples_speech_train.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + file=$( + find "data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz" + find "data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz" + find "data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz" + find "data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz" + ) + gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/ +/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/words.txt ]; then + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo ''; echo ''; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print " 0"; + } + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi diff --git a/egs/peoples_speech/ASR/shared b/egs/peoples_speech/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/peoples_speech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file