From f7fac705d5ae1340b2720de6faf1ff3e0bbd239f Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Sat, 19 Aug 2023 21:22:28 +0800
Subject: [PATCH] minor updates

---
 egs/swbd/ASR/conformer_ctc/asr_datamodule.py |  16 +--
 egs/swbd/ASR/conformer_ctc/train.py          |   2 +-
 egs/swbd/ASR/local/compute_fbank_eval2000.py | 138 +++++++++++++++++++
 egs/swbd/ASR/local/compute_fbank_swbd.py     | 106 +++++++-------
 egs/swbd/ASR/prepare.sh                      |  45 +++++++-
 5 files changed, 242 insertions(+), 65 deletions(-)
 create mode 100755 egs/swbd/ASR/local/compute_fbank_eval2000.py

diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
index 41da5ab3d..7fb4515e1 100644
--- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
@@ -225,8 +225,6 @@ class SwitchBoardAsrDataModule:
         else:
             logging.info("Disable MUSAN")
 
-        cuts_train = cuts_train.trim_to_supervisions(keep_overlapping=False)
-
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -393,18 +391,16 @@ class SwitchBoardAsrDataModule:
     @lru_cache()
     def train_all_cuts(self) -> CutSet:
         logging.info("switchboard: About to get train cuts")
-        return (
-            load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz")
-            .subset(last=2388)
-        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "swbd_cuts_all.jsonl.gz"
+        ).subset(last=2388)
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("switchboard: About to get dev cuts")
-        return (
-            load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz")
-            .subset(first=50)
-        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "swbd_cuts_all.jsonl.gz"
+        ).subset(first=50)
 
     @lru_cache()
     def test_eval2000_cuts(self) -> CutSet:
diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py
index 370fa0f77..8ed837aaf 100755
--- a/egs/swbd/ASR/conformer_ctc/train.py
+++ b/egs/swbd/ASR/conformer_ctc/train.py
@@ -697,7 +697,7 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 600.0
+        return 1.0 <= c.duration
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py
new file mode 100755
index 000000000..c6e2890ff
--- /dev/null
+++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# Modified     2023  The Chinese University of Hong Kong (author: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the eval2000 test set of SwitchBoard.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank. If None, we will use all""",
+    )
+
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_fbank_switchboard(
+    dir_name: str,
+    bpe_model: Optional[str] = None,
+    dataset: Optional[str] = None,
+    perturb_speed: Optional[bool] = True,
+):
+    src_dir = Path(f"data/manifests/{dir_name}")
+    output_dir = Path(f"data/fbank/{dir_name}")
+    num_jobs = min(1, os.cpu_count())
+    num_mel_bins = 80
+
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    if dataset is None:
+        dataset_parts = ("all",)
+    else:
+        dataset_parts = dataset.split(" ", -1)
+
+    prefix = dir_name
+    suffix = "jsonl.gz"
+    manifests = {
+        "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz",
+    }
+    assert manifests is not None
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        partition = "all"
+        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+        print(cuts_filename)
+        if (output_dir / cuts_filename).is_file():
+            logging.info(f"{prefix} already exists - skipping.")
+            return
+        logging.info(f"Processing {prefix}")
+        cut_set = CutSet.from_file(manifests[prefix]).resample(16000)
+
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+            # when an executor is specified, make more partitions
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+        cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    compute_fbank_switchboard(
+        dir_name="eval2000",
+        bpe_model=args.bpe_model,
+        dataset=args.dataset,
+        perturb_speed=args.perturb_speed,
+    )
diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py
index a89130bae..dd82220c0 100755
--- a/egs/swbd/ASR/local/compute_fbank_swbd.py
+++ b/egs/swbd/ASR/local/compute_fbank_swbd.py
@@ -70,18 +70,25 @@ def get_args():
         help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
     )
 
+    parser.add_argument(
+        "--split-index",
+        type=int,
+        required=True,
+    )
+
     return parser.parse_args()
 
 
 def compute_fbank_switchboard(
     dir_name: str,
+    split_index: int,
     bpe_model: Optional[str] = None,
     dataset: Optional[str] = None,
     perturb_speed: Optional[bool] = True,
 ):
     src_dir = Path(f"data/manifests/{dir_name}")
-    output_dir = Path(f"data/fbank/{dir_name}")
-    num_jobs = min(15, os.cpu_count())
+    output_dir = Path(f"data/fbank/{dir_name}_split16")
+    num_jobs = min(1, os.cpu_count())
     num_mel_bins = 80
 
     if bpe_model:
@@ -96,54 +103,48 @@ def compute_fbank_switchboard(
 
     prefix = dir_name
     suffix = "jsonl.gz"
-    manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts,
-        output_dir=src_dir,
-        prefix=prefix,
-        suffix=suffix,
-    )
-    assert manifests is not None
+    split_dir = Path("data/manifests/swbd_split16/")
 
-    assert len(manifests) == len(dataset_parts), (
-        len(manifests),
-        len(dataset_parts),
-        list(manifests.keys()),
-        dataset_parts,
-    )
-
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=8000))
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000))
 
     with get_executor() as ex:  # Initialize the executor only once.
-        for partition, m in manifests.items():
-            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
-            if (output_dir / cuts_filename).is_file():
-                logging.info(f"{partition} already exists - skipping.")
-                continue
-            logging.info(f"Processing {partition}")
-            cut_set = CutSet.from_manifests(
-                recordings=m["recordings"],
-                supervisions=m["supervisions"],
+        partition = "all"
+        cuts_filename = (
+            f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}"
+        )
+        print(cuts_filename)
+        if (output_dir / cuts_filename).is_file():
+            logging.info(f"{prefix} already exists - skipping.")
+            return
+        logging.info(f"Processing {prefix}")
+        cut_set = (
+            CutSet.from_file(
+                split_dir
+                / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz"
             )
+            .resample(16000)
+            .to_eager()
+            .filter(lambda c: c.duration > 2.0)
+        )
 
-            if "train" in partition:
-                if bpe_model:
-                    cut_set = filter_cuts(cut_set, sp)
-                if perturb_speed:
-                    logging.info(f"Doing speed perturb")
-                    cut_set = (
-                        cut_set
-                        + cut_set.perturb_speed(0.9)
-                        + cut_set.perturb_speed(1.1)
-                    )
-            cut_set = cut_set.compute_and_store_features(
-                extractor=extractor,
-                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
-                # when an executor is specified, make more partitions
-                num_jobs=num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomChunkyWriter,
-            )
-            cut_set.to_file(output_dir / cuts_filename)
+        if bpe_model:
+            cut_set = filter_cuts(cut_set, sp)
+        if perturb_speed:
+            logging.info(f"Doing speed perturb")
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}",
+            # when an executor is specified, make more partitions
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+        cut_set = cut_set.trim_to_supervisions(
+            keep_overlapping=False,
+            min_duration=None,
+        )
+        cut_set.to_file(output_dir / cuts_filename)
 
 
 if __name__ == "__main__":
@@ -152,10 +153,11 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     logging.info(vars(args))
-    for dir_name in ["swbd", "eval2000"]:
-        compute_fbank_switchboard(
-            dir_name=dir_name,
-            bpe_model=args.bpe_model,
-            dataset=args.dataset,
-            perturb_speed=args.perturb_speed,
-        )
+
+    compute_fbank_switchboard(
+        dir_name="swbd",
+        split_index=args.split_index,
+        bpe_model=args.bpe_model,
+        dataset=args.dataset,
+        perturb_speed=args.perturb_speed,
+    )
diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh
index 1ac5e8a1e..50af97225 100755
--- a/egs/swbd/ASR/prepare.sh
+++ b/egs/swbd/ASR/prepare.sh
@@ -76,10 +76,37 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz
   mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz
 
+  lhotse cut simple \
+    -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \
+    -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \
+    data/manifests/swbd/swbd_train_all.jsonl.gz
+  lhotse cut trim-to-supervisions \
+    --discard-overlapping \
+    --discard-extra-channels \
+    data/manifests/swbd/swbd_train_all.jsonl.gz \
+    data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz
+
+  num_splits=16
+  mkdir -p data/manifests/swbd_split${num_splits}
+  lhotse split ${num_splits} \
+    data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \
+    data/manifests/swbd_split${num_splits}
+
   lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000
 
   ./local/normalize_eval2000.py \
     data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \
     data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz
+
+  lhotse cut simple \
+    -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \
+    -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \
+    data/manifests/eval2000/eval2000_cuts_all.jsonl.gz
+
+  lhotse cut trim-to-supervisions \
+    --discard-overlapping \
+    --discard-extra-channels \
+    data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \
+    data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz
 
   # ./local/rt03_data_prep.sh $rt03_dir
@@ -116,13 +143,27 @@ fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for SwitchBoard"
-  mkdir -p data/fbank
   if [ ! -e data/fbank/.swbd.done ]; then
-    ./local/compute_fbank_swbd.py
+    mkdir -p data/fbank/swbd_split16/
+    for index in $(seq 1 16); do
+      ./local/compute_fbank_swbd.py --split-index ${index} &
+    done
+    wait
+    pieces=$(find data/fbank/swbd_split16 -name "swbd_cuts_all.*.jsonl.gz")
+    lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz
     touch data/fbank/.swbd.done
   fi
 fi
 
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for eval2000"
+  if [ ! -e data/fbank/.eval2000.done ]; then
+    mkdir -p data/fbank/eval2000/
+    ./local/compute_fbank_eval2000.py
+    touch data/fbank/.eval2000.done
+  fi
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   mkdir -p data/fbank
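-- 
A usage sketch, not part of the diff above (the shard index and flag values
are illustrative): once stage 1 has written the 16-way split to
data/manifests/swbd_split16, the fbank features of a single shard can be
recomputed by hand from egs/swbd/ASR with

    # Recompute shard 7 of 16, with speed perturbation enabled.
    ./local/compute_fbank_swbd.py --split-index 7 --perturb-speed true

--split-index is required and ranges over 1..16 to match `seq 1 16` in stage 3;
--perturb-speed defaults to false, and --bpe-model may optionally be passed to
filter short and long cuts before extraction.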