add the voxpopuli recipe (#1374)

* add the `voxpopuli` recipe

- this is the data preparation
- there is no ASR training and no results

* update PR #1374 (feedback from @csukuangfj)

- fixing .py headers and docstrings
- removing BUT-specific parts of `prepare.sh`
- adding assert `num_jobs >= num_workers` to `compute_fbank.py`
- narrowing list of languages
  (let's limit to ASR sets with transcripts for now)
- adding links to `README.md`
- extending `text_from_manifest.py`
Karel Vesely 2023-11-16 07:38:31 +01:00 committed by GitHub
parent 6d275ddf9f
commit 59c943878f
16 changed files with 1296 additions and 0 deletions

@@ -0,0 +1,38 @@
# Readme
This recipe contains the data preparation for the
[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset
[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf).
For now it covers only data preparation; there is no ASR training and no results yet.
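Data preparation is driven by `prepare.sh`. A minimal invocation sketch (the `--lang`, `--task`, `--stage` and `--stop-stage` options are parsed by `shared/parse_options.sh`; defaults are `--lang en --task asr`):
```bash
cd egs/voxpopuli/ASR

# run all data-preparation stages for the English ASR subset;
# use --stage / --stop-stage to run only part of the pipeline
./prepare.sh --lang en --task asr --stage 0 --stop-stage 8
```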
## Audio per language
| Language | Size   | Hrs. untranscribed | Hrs. transcribed |
|----------|--------|--------------------|------------------|
| bg | 295G | 17.6K | - |
| cs | 308G | 18.7K | 62 |
| da | 233G | 13.6K | - |
| de | 379G | 23.2K | 282 |
| el | 305G | 17.7K | - |
| en | 382G | 24.1K | 543 |
| es | 362G | 21.4K | 166 |
| et | 179G | 10.6K | 3 |
| fi | 236G | 14.2K | 27 |
| fr | 376G | 22.8K | 211 |
| hr | 132G | 8.1K | 43 |
| hu | 297G | 17.7K | 63 |
| it | 361G | 21.9K | 91 |
| lt | 243G | 14.4K | 2 |
| lv | 217G | 13.1K | - |
| mt | 147G | 9.1K | - |
| nl | 322G | 19.0K | 53 |
| pl | 348G | 21.2K | 111 |
| pt | 300G | 17.5K | - |
| ro | 296G | 17.9K | 89 |
| sk | 201G | 12.1K | 35 |
| sl | 190G | 11.3K | 10 |
| sv | 272G | 16.3K | - |
| | | | |
| total | 6.3T | 384K | 1791 |

@@ -0,0 +1,248 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of VoxPopuli dataset.
Usage example:
python3 ./local/compute_fbank.py \
--src-dir data/fbank --output-dir data/fbank \
--num-jobs 100 --num-workers 25 \
--prefix "voxpopuli-${task}-${lang}" \
--dataset train \
--trim-to-supervisions True \
--speed-perturb True
It looks for the raw CutSet at `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`.
The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats`
and the CutSet manifest is stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`.
Typically, the number of workers is smaller than the number of jobs
(see --num-jobs 100 --num-workers 25 in the example above),
and the number of jobs must be at least the number of workers (this is checked).
"""
import argparse
import logging
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
is_caching_enabled,
set_caching_enabled,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to the bpe.model. If not None, we will remove short and
long utterances before extracting features""",
)
parser.add_argument(
"--src-dir",
type=str,
help="""Folder with the input manifest files.""",
default="data/manifests",
)
parser.add_argument(
"--output-dir",
type=str,
help="""Folder with the output manifests (cuts) and feature files.""",
default="data/fbank",
)
parser.add_argument(
"--prefix",
type=str,
help="""Prefix of the manifest files.""",
default="",
)
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank (train,test,dev).""",
default=None,
)
parser.add_argument(
"--num-jobs",
type=int,
help="""Number of jobs (i.e. files with extracted features)""",
default=50,
)
parser.add_argument(
"--num-workers",
type=int,
help="""Number of parallel workers""",
default=10,
)
parser.add_argument(
"--speed-perturb",
type=str2bool,
default=False,
help="""Enable speed perturbation for the set.""",
)
parser.add_argument(
"--trim-to-supervisions",
type=str2bool,
default=False,
help="""Apply `trim-to-supervision` to cut set.""",
)
return parser.parse_args()
def compute_fbank_features(args: argparse.Namespace):
set_caching_enabled(True) # lhotse
src_dir = Path(args.src_dir)
output_dir = Path(args.output_dir)
num_jobs = args.num_jobs
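# never use more workers than there are CPU cores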
num_workers = min(args.num_workers, os.cpu_count())
num_mel_bins = 80
bpe_model = args.bpe_model
if bpe_model:
logging.info(f"Loading {bpe_model}")
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
prefix = args.prefix  # e.g. "voxpopuli-asr-en"
dataset = args.dataset
suffix = "jsonl.gz"
cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}")
cuts_raw = CutSet.from_file(cuts_raw_filename)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}")
if (output_dir / cuts_filename).is_file():
logging.info(f"{output_dir/cuts_filename} already exists - skipping.")
return
logging.info(f"Processing {output_dir/cuts_filename}")
cut_set = cuts_raw
if bpe_model:
cut_set = filter_cuts(cut_set, sp)
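# optional 3-fold speed perturbation: keep the original cuts and add 0.9x and 1.1x speed copies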
if args.speed_perturb:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
if args.trim_to_supervisions:
logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}")
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
else:
logging.info(
"Not doing `trim_to_supervisions()`, "
"to enable use --trim-to-supervision=True"
)
cut_set = cut_set.to_eager() # disallow lazy evaluation (sorting requires it)
cut_set = cut_set.sort_by_recording_id() # enhances AudioCache hit rate
# We typically use `num_jobs=100, num_workers=20`
# - this is helpful for large databases
# - both values are configurable externally
assert num_jobs >= num_workers, (num_jobs, num_workers)
executor = ProcessPoolExecutor(
max_workers=num_workers,
mp_context=multiprocessing.get_context("spawn"),
initializer=set_caching_enabled,
initargs=(is_caching_enabled(),),
)
logging.info(
f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}"
)
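# the cut set is split into `num_jobs` shards; feature extraction runs in `num_workers` parallel worker processes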
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir / prefix}-{dataset}_feats",
num_jobs=num_jobs,
executor=executor,
storage_type=LilcomChunkyWriter,
)
# correct small deviations of duration, caused by speed-perturbation
for cut in cut_set:
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
duration_difference = abs(cut.supervisions[0].duration - cut.duration)
tolerance = 0.02 # 20ms
if duration_difference == 0.0:
pass
elif duration_difference <= tolerance:
logging.info(
"small mismatch of the supervision duration "
f"(Δt = {duration_difference*1000}ms), "
f"correcting : cut.duration {cut.duration} -> "
f"supervision {cut.supervisions[0].duration}"
)
cut.supervisions[0].duration = cut.duration
else:
logging.error(
"mismatch of cut/supervision duration "
f"(Δt = {duration_difference*1000}ms) : "
f"cut.duration {cut.duration}, "
f"supervision {cut.supervisions[0].duration}"
)
raise ValueError(
"mismatch of cut/supervision duration "
f"(Δt = {duration_difference*1000}ms)"
)
# store the cutset
logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`")
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
compute_fbank_features(args)

@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum durations
for removing short and long utterances during training.
Usage example:
python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz
See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""
import argparse
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz")
parser.add_argument(
"filename",
help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz",
)
return parser.parse_args()
def main():
args = get_args()
cuts = load_manifest_lazy(args.filename)
cuts.describe()
if __name__ == "__main__":
main()

@@ -0,0 +1,93 @@
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script computes durations of datasets from
the SupervisionSet manifests.
Usage example:
python3 ./local/duration_from_supervision_manifest.py \
data/manifests/*_supervisions*.jsonl.gz
"""
import argparse
import gzip
import json
import logging
import re
import sys
def get_args():
parser = argparse.ArgumentParser(
"Read the raw text from the 'supervisions.jsonl.gz'"
)
parser.add_argument(
"filename",
help="supervisions.jsonl.gz",
nargs="+",
)
return parser.parse_args()
def main():
args = get_args()
logging.info(vars(args))
total_duration = 0.0
total_n_utts = 0
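# '-' reads from stdin; '*.jsonl.gz' files are opened with gzip, anything else as plain text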
for fname in args.filename:
if fname == "-":
fd = sys.stdin
elif re.match(r".*\.jsonl\.gz$", fname):
fd = gzip.open(fname, mode="r")
else:
fd = open(fname, mode="r")
fname_duration = 0.0
n_utts = 0
for line in fd:
js = json.loads(line)
fname_duration += js["duration"]
n_utts += 1
print(
f"Duration: {fname_duration/3600:7.2f} hours "
f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}"
)
if fd != sys.stdin:
fd.close()
total_duration += fname_duration
total_n_utts += n_utts
print(
f"Total duration: {total_duration/3600:7.2f} hours "
f"(eq. {total_duration:7.0f} seconds)"
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

@@ -0,0 +1 @@
../../../librispeech/ASR/local/filter_cuts.py

@@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
# 2023 Brno University of Technology (author: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the database.
- Convert RecordingSet and SupervisionSet to CutSet.
- Apply text normalization to the transcripts.
- The re-normalized `orig_text` is used as the `text` transcript.
- The text normalization separates punctuation from words.
- It also puts a capital letter at the beginning of each sentence.
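Example of the normalization (illustrative):
  input : "This is fine. Yes, you are right."
  output: This is fine . Yes , you are right .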
The script is inspired by:
`egs/commonvoice/ASR/local/preprocess_commonvoice.py`
Usage example:
python3 ./local/preprocess_voxpopuli.py \
--task asr --lang en
"""
import argparse
import logging
from pathlib import Path
from typing import Optional
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
# from local/
from separate_punctuation import separate_punctuation
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
from icefall.utils import str2bool
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
default=None,
)
parser.add_argument(
"--task",
type=str,
help="""Task of VoxPopuli""",
default="asr",
)
parser.add_argument(
"--lang",
type=str,
help="""Language of VoxPopuli""",
required=True,
)
parser.add_argument(
"--use-original-text",
type=str2bool,
help="""Use 'original_text' from the annoattaion file,
otherwise 'normed_text' will be used
(see `data/manifests/${task}_${lang}.tsv.gz`).
""",
default=False,
)
return parser.parse_args()
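# chain the two normalizers: first split punctuation from the words,
# then restore the capital letter at the beginning of each sentence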
def normalize_text(utt: str) -> str:
utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt))
return utt
def preprocess_voxpopuli(
task: str,
language: str,
dataset: Optional[str] = None,
use_original_text: bool = False,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
if dataset is None:
dataset_parts = (
"dev",
"test",
"train",
)
else:
dataset_parts = dataset.split(" ", -1)
logging.info("Loading manifest")
prefix = f"voxpopuli-{task}-{language}"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
suffix=suffix,
prefix=prefix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
if use_original_text:
logging.info("Using 'original_text' from the annotation file.")
logging.info(f"Normalizing text in {partition}")
for sup in m["supervisions"]:
# `orig_text` includes punctuation and true-case
orig_text = str(sup.custom["orig_text"])
# we replace `text` by normalized `orig_text`
sup.text = normalize_text(orig_text)
else:
logging.info("Using 'normed_text' from the annotation file.")
# remove supervisions with empty 'text'
m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0)
# Create cut manifest with long-recordings.
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
).resample(16000)
# Store the cut set incl. the resampling.
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
preprocess_voxpopuli(
task=args.task,
language=args.lang,
dataset=args.dataset,
use_original_text=args.use_original_text,
)
logging.info("Done")
if __name__ == "__main__":
main()

@@ -0,0 +1,130 @@
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script chops the punctuation as standalone tokens.
Example:
input: "This is fine. Yes, you are right."
output: "This is fine . Yes , you are right ."
The script also handles exceptions in a hard-coded fashion.
(the same functionality could be achieved with `nltk.tokenize.word_tokenize()`,
but that would add an extra dependency)
It can be used as a module, or as an executable script.
Usage example #1:
`from separate_punctuation import separate_punctuation`
Usage example #2:
```
python3 ./local/separate_punctuation.py \
--ignore-columns 1 \
< ${kaldi_data}/text
```
"""
import re
import sys
from argparse import ArgumentParser
def separate_punctuation(text: str) -> str:
"""
Text filtering function for separating punctuation.
Example:
input: "This is fine. Yes, you are right."
output: "This is fine . Yes , you are right ."
The exceptions for which the punctuation is
not split are hard-coded.
"""
# remove non-desired punctuation symbols
text = re.sub('["„“«»]', "", text)
# separate [,.!?;] punctuation from words by space
text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text)
text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text)
# split to tokens
tokens = text.split()
tokens_out = []
# re-join the special cases of punctuation
for ii, tok in enumerate(tokens):
# no rewriting for 1st and last token
if ii > 0 and ii < len(tokens) - 1:
# **RULES ADDED FOR CZECH COMMON VOICE**
# fix "27 . dubna" -> "27. dubna", but keep punctuation separate,
if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower():
tokens_out[-1] = tokens_out[-1] + "."
continue
# fix "resp . pak" -> "resp. pak"
if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower():
tokens_out[-1] = tokens_out[-1] + "."
continue
# **RULES ADDED FOR ENGLISH COMMON VOICE**
# fix "A ." -> "A."
if tok == "." and re.match(r"^[A-Z]S", tokens[ii - 1]):
tokens_out[-1] = tokens_out[-1] + "."
continue
# fix "Mr ." -> "Mr."
exceptions = set(["Mr", "Mrs", "Ms"])
if tok == "." and tokens[ii - 1] in exceptions:
tokens_out[-1] = tokens_out[-1] + "."
continue
tokens_out.append(tok)
return " ".join(tokens_out)
def get_args():
parser = ArgumentParser(
description="Separate punctuation from words: 'hello.' -> 'hello .'"
)
parser.add_argument(
"--ignore-columns", type=int, default=1, help="skip number of initial columns"
)
return parser.parse_args()
def main():
args = get_args()
max_split = args.ignore_columns
while True:
line = sys.stdin.readline()
if not line:
break
*key, text = line.strip().split(maxsplit=max_split)
text_norm = separate_punctuation(text)
print(" ".join(key), text_norm)
if __name__ == "__main__":
main()

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`.
Usage example:
python3 ./local/text_from_manifest.py \
data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz
"""
import argparse
import gzip
import json
def get_args():
parser = argparse.ArgumentParser(
"Read the raw text from the 'supervisions.jsonl.gz'"
)
parser.add_argument("filename", help="supervisions.jsonl.gz")
return parser.parse_args()
def main():
args = get_args()
with gzip.open(args.filename, mode="r") as fd:
for line in fd:
js = json.loads(line)
if "text" in js:
print(js["text"]) # supervisions.jsonl.gz
elif "supervisions" in js:
for s in js["supervisions"]:
print(s["text"]) # cuts.jsonl.gz
else:
raise Exception(f"Unknown jsonl format of {args.filename}")
if __name__ == "__main__":
main()

@@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

@@ -0,0 +1,113 @@
#!/usr/bin/env python3
# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script introduces an initial capital letter at the beginning of each sentence.
It can be used as a module, or as an executable script.
Usage example #1:
`from uppercase_begin_of_sentence import UpperCaseBeginOfSentence`
Usage example #2:
```
python3 ./local/uppercase_begin_of_sentence.py \
--ignore-columns 1 \
< ${kaldi_data}/text
```
"""
import re
import sys
from argparse import ArgumentParser
class UpperCaseBeginOfSentence:
"""
This class introduces an initial capital letter at the beginning of each sentence.
A capital letter is used if the previous token was one of the punctuation
symbols in `set([".", "!", "?"])`.
Whether the previous token was punctuation is also remembered across
`process_line_text()` calls.
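Example (punctuation already separated from the words):
  input : hello . how are you ?
  output: Hello . How are you ?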
"""
def __init__(self):
# The 1st word will have Title-case
# This variable transfers context from previous line
self.prev_token_is_punct = True
def process_line_text(self, line_text: str) -> str:
"""
It is assumed that punctuation in `line_text` was already separated,
example: "This is fine . Yes , you are right ."
"""
words = line_text.split()
punct_set = set([".", "!", "?"])
for ii, w in enumerate(words):
# punctuation ?
if w in punct_set:
self.prev_token_is_punct = True
continue
# change case of word...
if self.prev_token_is_punct:
if re.match("<", w):
continue # skip <symbols>
# apply Title-case only on lowercase words.
if w.islower():
words[ii] = w.title()
# change state
self.prev_token_is_punct = False
line_text_uc = " ".join(words)
return line_text_uc
def get_args():
parser = ArgumentParser(
description="Put upper-case at the beginning of a sentence."
)
parser.add_argument(
"--ignore-columns", type=int, default=4, help="skip number of initial columns"
)
return parser.parse_args()
def main():
args = get_args()
uc_bos = UpperCaseBeginOfSentence()
max_split = args.ignore_columns
while True:
line = sys.stdin.readline()
if not line:
break
line = line.strip()
if len(line.split()) > 1:
*key, text = line.strip().split(maxsplit=max_split) # parse,
text_uc = uc_bos.process_line_text(text) # process,
print(" ".join(key), text_uc) # print,
else:
print(line)
if __name__ == "__main__":
main()

@@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py

@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Brno University of Technology (authors: Karel Veselý)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within Cut time bounds
- Duration of Cut and Supervision are equal
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
(Based on: `librispeech/ASR/local/validate_manifest.py`)
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset.speech_recognition import validate_for_asr
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"cutset_manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def validate_one_supervision_per_cut(c: Cut):
if len(c.supervisions) != 1:
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
def validate_supervision_and_cut_time_bounds(c: Cut):
tol = 2e-3 # same tolerance as in 'validate_for_asr()'
s = c.supervisions[0]
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
if s.start < -tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} must not be negative."
)
if s.start > tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} "
"is not at the beginning of the Cut. "
"Please apply `lhotse cut trim-to-supervisions`."
)
if c.start + s.end > c.end + tol:
raise ValueError(
f"{c.id}: Supervision end time {c.start+s.end} is larger "
f"than cut end time {c.end}"
)
if s.duration != c.duration:
raise ValueError(
f"{c.id}: Cut duration {c.duration} and supervision duration "
f"{s.duration} must be the same.\n"
f"The difference causes problems in the training code : "
f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n"
f"Did you forget to apply `trim_to_supervisions()` ?"
)
def main():
args = get_args()
manifest = args.cutset_manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
try:
for c in cut_set:
validate_one_supervision_per_cut(c)
validate_supervision_and_cut_time_bounds(c)
# Validation from K2 training
# - checks supervision start is 0
# - checks supervision.duration is not longer than cut.duration
# - there is tolerance 2ms
validate_for_asr(cut_set)
except BaseException as e:
logging.error(str(e))
raise
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/voxpopuli/ASR/prepare.sh Executable file

@@ -0,0 +1,257 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -euxo pipefail
nj=20
stage=-1
stop_stage=100
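# Overview of the stages below:
#   0: download VoxPopuli audio and musan
#   1: prepare VoxPopuli manifests
#   2: prepare musan manifest
#   3: preprocess VoxPopuli manifests (cut sets, text normalization)
#   4: compute fbank for dev/test
#   5: compute fbank for train (with speed perturbation)
#   6: validate the fbank cut manifests
#   7: compute fbank for musan
#   8: prepare the BPE-based lang directory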
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/voxpopuli/raw_audios/$lang/$year
# This directory contains *.ogg files with audio downloaded and extracted from archives:
# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar
#
# - Note: the VoxPopuli transcripts are downloaded to a temporary folder
# as part of `lhotse prepare voxpopuli` from:
# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT
musan_dir=${dl_dir}/musan
#musan_dir=/mnt/matylda2/data/MUSAN # BUT
# Choose value from ASR_LANGUAGES:
#
# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
# "sk", "sl", "et", "lt" ]
#
# See ASR_LANGUAGES in:
# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4
lang=en
task=asr
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx_${lang},
# data/lang_bpe_yyy_${lang} if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# All files generated by this script are saved in "data/${lang}".
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
mkdir -p data/${lang}
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
log "musan_dir: $musan_dir"
log "task: $task, lang: $lang"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/voxpopuli,
# you can create a symlink
#
# ln -sfv /path/to/voxpopuli $dl_dir/voxpopuli
#
if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then
lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $musan_dir/musan ]; then
lhotse download musan $musan_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare VoxPopuli manifest"
# We assume that you have downloaded the VoxPopuli corpus
# to $dl_dir/voxpopuli
if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then
# Warning: this step requires an Internet connection (it downloads the transcripts to a temporary folder)
lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests
touch data/manifests/.voxpopuli-${task}-${lang}.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $musan_dir/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
#lhotse prepare musan $dl_dir/musan data/manifests
lhotse prepare musan $musan_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Preprocess VoxPopuli manifest"
mkdir -p data/fbank
if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then
# recordings + supervisions -> cutset
./local/preprocess_voxpopuli.py --task $task --lang $lang \
--use-original-text True
touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli"
mkdir -p data/fbank
for dataset in "dev" "test"; do
if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then
./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
--num-jobs 50 --num-workers ${nj} \
--prefix "voxpopuli-${task}-${lang}" \
--dataset ${dataset} \
--trim-to-supervisions True
touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done
fi
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for train set of VoxPopuli"
if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then
./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
--num-jobs 100 --num-workers ${nj} \
--prefix "voxpopuli-${task}-${lang}" \
--dataset train \
--trim-to-supervisions True \
--speed-perturb True
touch data/fbank/.voxpopuli-${task}-${lang}-train.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Validate fbank manifests for VoxPopuli"
for dataset in "dev" "test" "train"; do
mkdir -p data/fbank/log/
./local/validate_cutset_manifest.py \
data/fbank/voxpopuli-${task}-${lang}_cuts_${dataset}.jsonl.gz \
2>&1 | tee data/fbank/log/validate_voxpopuli-${task}-${lang}_cuts_${dataset}.log
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}_${lang}
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
file=$(
find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz"
)
local/text_from_manifest.py $file >$lang_dir/transcript_words.txt
# gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
# Ensure space only appears once
#sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
#sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
fi
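# Build words.txt: map every word to an integer id; <eps> gets id 0,
# and the symbols #0, <s>, </s> are appended at the end.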
if [ ! -f $lang_dir/words.txt ]; then
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_dir/words.txt
(echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
cat - $lang_dir/words.txt | sort | uniq | awk '
BEGIN {
print "<eps> 0";
}
{
if ($1 == "<s>") {
print "<s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
if ($1 == "</s>") {
print "</s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $lang_dir/words || exit 1;
mv $lang_dir/words $lang_dir/words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi

egs/voxpopuli/ASR/shared Symbolic link

@@ -0,0 +1 @@
../../../icefall/shared/