Update PR#1374 (feedback from @csukuangfj)

- fixing .py headers and docstrings
- removing BUT specific parts of `prepare.sh`
- adding assert `num_jobs >= num_workers` to `compute_fbank.py`
- narrowing list of languages
  (let's limit to ASR sets with transcripts for now)
- adding links to `README.md`
- extending `text_from_manifest.py`
Karel Vesely 2023-11-13 17:23:24 +01:00
parent 4ec48f30b1
commit 07a229ac81
10 changed files with 96 additions and 95 deletions

View File

@ -1,6 +1,8 @@
# Readme
This recipe contains data preparation for the VoxPopuli dataset.
This recipe contains data preparation for the
[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset
[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf).
At the moment, it does not include model training.

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -36,23 +36,29 @@ located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`.
The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats`
and the CutSet manifest is stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`.
The number of workers is smaller than nunber of jobs
Typically, the number of workers is smaller than number of jobs
(see --num-jobs 100 --num-workers 25 in the example).
Also, the number of jobs should be at least the number of workers (this is checked).
"""
import argparse
import logging
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import is_caching_enabled, set_caching_enabled
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
is_caching_enabled,
set_caching_enabled,
)
from icefall.utils import str2bool
@ -128,7 +134,6 @@ def get_args():
def compute_fbank_features(args: argparse.Namespace):
set_caching_enabled(True) # lhotse
src_dir = Path(args.src_dir)
@ -181,6 +186,7 @@ def compute_fbank_features(args: argparse.Namespace):
# We typically use `num_jobs=100, num_workers=20`
# - this is helpful for large databases
# - both values are configurable externally
assert num_jobs >= num_workers, (num_jobs, num_workers)
executor = ProcessPoolExecutor(
max_workers=num_workers,
mp_context=multiprocessing.get_context("spawn"),
@ -202,7 +208,7 @@ def compute_fbank_features(args: argparse.Namespace):
# correct small deviations of duration, caused by speed-perturbation
for cut in cut_set:
assert len(cut.supervisions) == 1
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
duration_difference = abs(cut.supervisions[0].duration - cut.duration)
tolerance = 0.02 # 20ms
if duration_difference == 0.0:
@ -211,7 +217,7 @@ def compute_fbank_features(args: argparse.Namespace):
logging.info(
"small mismatch of the supervision duration "
f"(Δt = {duration_difference*1000}ms), "
f"corretcing : cut.duration {cut.duration} -> "
f"correcting : cut.duration {cut.duration} -> "
f"supervision {cut.supervisions[0].duration}"
)
cut.supervisions[0].duration = cut.duration
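
As background for the assert added above: the CutSet is split into `num_jobs` pieces while at most `num_workers` processes run at once, so `num_jobs < num_workers` would leave workers idle. A minimal sketch of the pattern, with a hypothetical `process_chunk` standing in for the recipe's actual per-chunk feature extraction:
```
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

num_jobs = 100   # pieces the CutSet is split into
num_workers = 25 # processes running concurrently
assert num_jobs >= num_workers, (num_jobs, num_workers)

def process_chunk(chunk_id: int) -> str:
    # placeholder for per-chunk fbank extraction (hypothetical helper)
    return f"chunk {chunk_id} done"

if __name__ == "__main__":
    # "spawn" starts clean child processes, avoiding trouble with
    # forked torch/CUDA state in the parent
    executor = ProcessPoolExecutor(
        max_workers=num_workers,
        mp_context=multiprocessing.get_context("spawn"),
    )
    futures = [executor.submit(process_chunk, i) for i in range(num_jobs)]
    for future in futures:
        print(future.result())
    executor.shutdown()
```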

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

View File

@ -1,5 +1,5 @@
#!/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -25,11 +25,11 @@ Usage example:
"""
import argparse
import gzip
import json
import logging
import sys
import gzip
import re
import sys
def get_args():
@ -54,7 +54,6 @@ def main():
total_n_utts = 0
for fname in args.filename:
if fname == "-":
fd = sys.stdin
elif re.match(r".*\.jsonl\.gz$", fname):
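
For context, the filename dispatch above follows a common pattern: `-` selects stdin and a `.jsonl.gz` suffix selects a gzip reader. A minimal sketch of that dispatch; the plain-text fallback branch is an illustrative assumption, not shown in the diff:
```
import gzip
import re
import sys

def open_text_stream(fname: str):
    # "-" means stdin; *.jsonl.gz is opened through gzip in text mode;
    # anything else is assumed to be a plain text file
    if fname == "-":
        return sys.stdin
    elif re.match(r".*\.jsonl\.gz$", fname):
        return gzip.open(fname, mode="rt")
    else:
        return open(fname, mode="r")
```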

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
# 2023 Brno University of Technology (author: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,8 +20,8 @@
Preprocess the database.
- Convert RecordingSet and SupervisionSet to CutSet.
- Apply text normalization to the transcripts.
- We take renormzlized `orig_text` as `text` transcripts.
- The the text normalization is separating punctuation from words.
- We take renormalized `orig_text` as `text` transcripts.
- The text normalization is separating punctuation from words.
- Also, we put a capital letter at the beginning of a sentence.
The script is inspired by:
@ -40,12 +41,12 @@ from typing import Optional
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# from local/
from separate_punctuation import separate_punctuation
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
from icefall.utils import str2bool
def get_args():
parser = argparse.ArgumentParser()
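
The normalization described in the docstring composes the two `local/` modules: punctuation is first separated from words, then sentence beginnings are capitalized. A hedged sketch of that composition; `process_line` is a hypothetical method name, since the diff does not show the class interface:
```
from separate_punctuation import separate_punctuation            # local/
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence  # local/

ucbos = UpperCaseBeginOfSentence()  # stateful: tracks sentence ends across lines

def normalize_text(orig_text: str) -> str:
    # 1) "This is fine. Yes, you are right." -> "This is fine . Yes , you are right ."
    text = separate_punctuation(orig_text)
    # 2) capitalize the beginning of each sentence
    #    (`process_line` is a hypothetical method name)
    return ucbos.process_line(text)
```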

View File

@ -1,5 +1,5 @@
#!/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -20,7 +20,7 @@ Example:
input: "This is fine. Yes, you are right."
output: "This is fine . Yes , you are right ."
The script also handles exceptions in a hard-coded fasion.
The script also handles exceptions in a hard-coded fashion.
(same functionality could be done with `nltk.tokenize.word_tokenize()`,
but that would be an extra dependency)
@ -28,17 +28,18 @@ The script also handles exceptions in a hard-coded fasion.
It can be used as a module, or as an executable script.
Usage example #1:
from separate_punctuation import separate_punctuation
`from separate_punctuation import separate_punctuation`
Usage example #2:
```
python3 ./local/separate_punctuation.py \
--ignore-columnts 1 \
${kaldi_data}/text
--ignore-columns 1 \
< ${kaldi_data}/text
```
"""
import sys
import re
import sys
from argparse import ArgumentParser
@ -67,10 +68,8 @@ def separate_punctuation(text: str) -> str:
# re-join the special cases of punctuation
for ii, tok in enumerate(tokens):
# no rewriting for 1st and last token
if ii > 0 and ii < len(tokens) - 1:
# **RULES ADDED FOR CZECH COMMON VOICE**
# fix "27 . dubna" -> "27. dubna", but keep punctuation separate,

View File

@ -1,5 +1,5 @@
#!/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Print the text contained in `supervisions.jsonl.gz`.
Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`.
Usage example:
python3 ./local/text_from_manifest.py \
@ -23,8 +23,8 @@ Usage example:
"""
import argparse
import json
import gzip
import json
def get_args():
@ -41,7 +41,13 @@ def main():
with gzip.open(args.filename, mode="r") as fd:
for line in fd:
js = json.loads(line)
print(js["text"])
if "text" in js:
print(js["text"]) # supervisions.jsonl.gz
elif "supervisions" in js:
for s in js["supervisions"]:
print(s["text"]) # cuts.jsonl.gz
else:
raise Exception(f"Unknown jsonl format of {args.filename}")
if __name__ == "__main__":
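
For reference, the two JSONL layouts handled by the branch above look roughly as follows; the field subsets are illustrative, real manifests carry more keys:
```
import json

# a line of supervisions.jsonl.gz carries "text" at the top level
sup_line = '{"id": "utt1", "text": "hello world"}'
# a line of cuts.jsonl.gz nests it under "supervisions"
cut_line = '{"id": "cut1", "supervisions": [{"id": "utt1", "text": "hello world"}]}'

for line in (sup_line, cut_line):
    js = json.loads(line)
    if "text" in js:
        print(js["text"])            # supervisions.jsonl.gz
    elif "supervisions" in js:
        for s in js["supervisions"]:
            print(s["text"])         # cuts.jsonl.gz
```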

View File

@ -1,5 +1,5 @@
#!/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,17 +19,18 @@ This script introduces initial capital letter at the beginning of a sentence.
It can be used as a module, or as an executable script.
Usage example #1:
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
`from uppercase_begin_of_sentence import UpperCaseBeginOfSentence`
Usage example #2:
```
python3 ./local/uppercase_begin_of_sentence.py \
--ignore-columnts 1 \
${kaldi_data}/text
--ignore-columns 1 \
< ${kaldi_data}/text
```
"""
import re
import sys
from argparse import ArgumentParser
@ -44,7 +45,6 @@ class UpperCaseBeginOfSentence:
"""
def __init__(self):
# The 1st word will have Title-case
# This variable transfers context from previous line
self.prev_token_is_punct = True
@ -59,7 +59,6 @@ class UpperCaseBeginOfSentence:
punct_set = set([".", "!", "?"])
for ii, w in enumerate(words):
# punctuation ?
if w in punct_set:
self.prev_token_is_punct = True
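
The `prev_token_is_punct` flag above carries sentence-boundary context across lines, so a sentence ending on one line capitalizes the first word of the next. A simplified, self-contained sketch of that behavior (a stand-in, not the actual class):
```
class UpperCaseBOS:
    """Simplified stand-in for UpperCaseBeginOfSentence (illustration only)."""

    def __init__(self):
        # True at the start of input and after ".", "!", "?";
        # transfers sentence-boundary context from the previous line
        self.prev_token_is_punct = True

    def process_line(self, line: str) -> str:
        punct_set = {".", "!", "?"}
        words = line.split()
        for ii, w in enumerate(words):
            if self.prev_token_is_punct and w not in punct_set:
                words[ii] = w[0].upper() + w[1:]  # Title-case 1st letter only
            self.prev_token_is_punct = w in punct_set
        return " ".join(words)

ucb = UpperCaseBOS()
print(ucb.process_line("hello . how are"))  # -> "Hello . How are"
print(ucb.process_line("you ? fine ."))     # -> "you ? Fine ."  (sentence continues)
```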

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -18,7 +19,8 @@
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within cut time bounds
- Supervision time bounds are within Cut time bounds
- Duration of Cut and Supervision are equal
We will add more checks later if needed.
@ -27,14 +29,13 @@ Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
(Based on: `librispeech/ASR/local/validate_manifest.py`)
"""
import argparse
import logging
from pathlib import Path
from icefall.utils import setup_logger
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset.speech_recognition import validate_for_asr
@ -49,14 +50,6 @@ def get_args():
help="Path to the manifest file",
)
parser.add_argument(
"--log-file",
type=str,
default=None,
required=True,
help="The filename to save the log.",
)
return parser.parse_args()
@ -101,8 +94,6 @@ def main():
args = get_args()
manifest = args.cutset_manifest
setup_logger(log_filename=f"{args.log_file}", log_level="info")
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
@ -125,4 +116,8 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
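
The three documented checks amount to assertions of the following shape. A hedged sketch over a lhotse manifest; the 20ms tolerance is borrowed from `compute_fbank.py` as an assumption, while the recipe itself delegates to `validate_for_asr`:
```
from lhotse import load_manifest_lazy

def validate_cuts(manifest: str, tolerance: float = 0.02) -> None:
    for cut in load_manifest_lazy(manifest):
        # single supervision per cut
        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
        sup = cut.supervisions[0]
        # supervision time bounds lie within the cut (times are cut-relative)
        assert sup.start >= 0 and sup.end <= cut.duration + tolerance, cut.id
        # durations of Cut and Supervision agree
        assert abs(sup.duration - cut.duration) <= tolerance, cut.id

validate_cuts("data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz")
```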

View File

@ -1,6 +1,7 @@
#!/usr/bin/env bash
. /mnt/matylda5/iveselyk/ASR_TOOLKITS/K2_SHERPA_PYTORCH20/conda-activate.sh
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -euxo pipefail
@ -8,28 +9,17 @@ nj=20
stage=-1
stop_stage=100
# Split data/${lang}set to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# [TODO update this]
# - $dl_dir/voxpopuli/raw_audios/$lang/$year
# This directory contains *.ogg files with audio downloaded and extracted from archives:
# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar
#
# - $dl_dir/$release/$lang
# This directory contains the following files downloaded from
# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
#
# - clips
# - dev.tsv
# - invalidated.tsv
# - other.tsv
# - reported.tsv
# - test.tsv
# - train.tsv
# - validated.tsv
# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder
# as part of `lhotse prepare voxpopuli` from:
# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
@ -39,19 +29,19 @@ num_splits=100
# - noise
# - speech
#dl_dir=$PWD/download
dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA
dl_dir=$PWD/download
#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT
#musan_dir=${dl_dir}/musan
musan_dir=/mnt/matylda2/data/MUSAN
musan_dir=${dl_dir}/musan
#musan_dir=/mnt/matylda2/data/MUSAN # BUT
# Choose vlues from:
# Choose a value from ASR_LANGUAGES:
#
# "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
# "sk", "sl", "et", "lt", "pt", "bg", "el", "lv", "mt", "sv", "da",
# "asr", "10k", "100k", "400k"
# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
# "sk", "sl", "et", "lt" ]
#
# See: https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/bin/modes/recipes/voxpopuli.py#L77
# See ASR_LANGUAGES in:
# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4
lang=en
task=asr
@ -102,12 +92,6 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
if [ ! -d $musan_dir/musan ]; then
lhotse download musan $musan_dir
fi
# pre-download the transcripts
DOWNLOAD_BASE_URL="https://dl.fbaipublicfiles.com/voxpopuli"
dir=data/manifests; mkdir -p ${dir}
wget --tries=10 --continue --progress=bar --directory-prefix=${dir} \
"${DOWNLOAD_BASE_URL}/annotations/asr/${task}_${lang}.tsv.gz"
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@ -115,7 +99,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the VoxPopuli corpus
# to $dl_dir/voxpopuli
if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then
# Warning : it requires Internet connection (it downloads transcripts)
# Warning: it requires an Internet connection (it downloads transcripts to ${tmpdir})
lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests
touch data/manifests/.voxpopuli-${task}-${lang}.done
fi
@ -150,7 +134,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
for dataset in "dev" "test"; do
if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then
./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
--num-jobs 50 --num-workers 10 \
--num-jobs 50 --num-workers ${nj} \
--prefix "voxpopuli-${task}-${lang}" \
--dataset ${dataset} \
--trim-to-supervisions True
@ -160,10 +144,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 6: Compute fbank for train set of VoxPopuli"
log "Stage 5: Compute fbank for train set of VoxPopuli"
if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then
./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
--num-jobs 100 --num-workers 25 \
--num-jobs 100 --num-workers ${nj} \
--prefix "voxpopuli-${task}-${lang}" \
--dataset train \
--trim-to-supervisions True \
@ -173,7 +157,17 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Compute fbank for musan"
log "Stage 6: Validate fbank manifests for VoxPopuli"
for dataset in "dev" "test" "train"; do
mkdir -p data/fbank/log/
./local/validate_cutset_manifest.py \
data/fbank/voxpopuli-asr-en_cuts_${dataset}.jsonl.gz \
2>&1 | tee data/fbank/log/validate_voxpopuli-asr-en_cuts_${dataset}.log
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
@ -181,8 +175,8 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare BPE based lang"
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}_${lang}