From 22f68dd34497539e1dd9c860760c50b4ecc7ef01 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 27 Nov 2023 14:51:30 +0800
Subject: [PATCH] add gss support

---
 .../ASR/local/compute_fbank_icmcasr.py    |  64 ++++----
 .../ASR/local/prepare_icmc_enhanced.py    | 163 ++++++++++++++++
 egs/icmcasr/ASR/local/prepare_icmc_gss.sh |  99 ++++++++++
 egs/icmcasr/ASR/prepare.sh                |  29 ++--
 4 files changed, 315 insertions(+), 40 deletions(-)
 create mode 100644 egs/icmcasr/ASR/local/prepare_icmc_enhanced.py
 create mode 100644 egs/icmcasr/ASR/local/prepare_icmc_gss.sh

diff --git a/egs/icmcasr/ASR/local/compute_fbank_icmcasr.py b/egs/icmcasr/ASR/local/compute_fbank_icmcasr.py
index 6e986cb4f..e5623634b 100755
--- a/egs/icmcasr/ASR/local/compute_fbank_icmcasr.py
+++ b/egs/icmcasr/ASR/local/compute_fbank_icmcasr.py
@@ -64,12 +64,12 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
         suffix="jsonl.gz",
     )
     # For GSS we already have cuts so we read them directly.
-    # manifests_gss = read_manifests_if_cached(
-    #     dataset_parts=["train", "dev"],
-    #     output_dir=src_dir,
-    #     prefix="icmcasr-gss",
-    #     suffix="jsonl.gz",
-    # )
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "dev"],
+        output_dir=src_dir,
+        prefix="icmcasr-gss",
+        suffix="jsonl.gz",
+    )
 
     sampling_rate = 16000
 
@@ -84,6 +84,10 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
     def _extract_feats(
         cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
     ) -> None:
+        # Check whether the features have already been computed.
+        if storage_path.exists() or storage_path.with_suffix(".lca").exists():
+            logging.info(f"{storage_path} exists, skipping feature extraction")
+            return
         if speed_perturb:
             logging.info(f"Doing speed perturb")
             cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
@@ -136,17 +140,17 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
     )
 
     logging.info("Processing train split GSS")
-    # cuts_gss = (
-    #     CutSet.from_manifests(**manifests_gss["train"])
-    #     .trim_to_supervisions(keep_overlapping=False)
-    #     .modify_ids(lambda x: x + "-gss")
-    # )
-    # _extract_feats(
-    #     cuts_gss,
-    #     output_dir / "feats_train_gss",
-    #     src_dir / "cuts_train_gss.jsonl.gz",
-    #     perturb_speed,
-    # )
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+        perturb_speed,
+    )
 
     logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
     for split in ["dev"]:
@@ -176,19 +180,19 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
                 batch_duration=500,
                 num_workers=4,
                 storage_type=LilcomChunkyWriter,
            )
         )
-        # logging.info(f"Processing {split} GSS")
-        # cuts_gss = (
-        #     CutSet.from_manifests(**manifests_gss[split])
-        #     .trim_to_supervisions(keep_overlapping=False)
-        #     .compute_and_store_features_batch(
-        #         extractor=extractor,
-        #         storage_path=output_dir / f"feats_{split}_gss",
-        #         manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
-        #         batch_duration=500,
-        #         num_workers=4,
-        #         storage_type=LilcomChunkyWriter,
-        #     )
-        # )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
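
For orientation, the GSS branch enabled above condenses to the sketch below. The kaldifeat-based extractor is an assumption (the patch does not show how `extractor` is constructed); the remaining calls, including the new skip-if-computed guard, are the ones visible in the hunks:

from pathlib import Path

from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig
from lhotse.recipes.utils import read_manifests_if_cached

src_dir = Path("data/manifests")
output_dir = Path("data/fbank")

manifests_gss = read_manifests_if_cached(
    dataset_parts=["train", "dev"],
    output_dir=src_dir,
    prefix="icmcasr-gss",
    suffix="jsonl.gz",
)

# Assumed extractor; the recipe defines its own outside the hunks shown.
extractor = KaldifeatFbank(KaldifeatFbankConfig(device="cuda"))

for split in ["train", "dev"]:
    storage_path = output_dir / f"feats_{split}_gss"
    # Mirrors the guard added in _extract_feats: skip work already done.
    if storage_path.exists() or storage_path.with_suffix(".lca").exists():
        continue
    cuts_gss = (
        CutSet.from_manifests(**manifests_gss[split])
        .trim_to_supervisions(keep_overlapping=False)
        .modify_ids(lambda cut_id: cut_id + "-gss")
    )
    cuts_gss.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=storage_path,
        manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
        batch_duration=500,
        num_workers=4,
        storage_type=LilcomChunkyWriter,
    )

Note that the real script also speed-perturbs the train split before extraction; that step is omitted from this sketch.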
diff --git a/egs/icmcasr/ASR/local/prepare_icmc_enhanced.py b/egs/icmcasr/ASR/local/prepare_icmc_enhanced.py
new file mode 100644
index 000000000..19f1fab72
--- /dev/null
+++ b/egs/icmcasr/ASR/local/prepare_icmc_enhanced.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for the ICMC-ASR GSS-enhanced dataset.
+
+import logging
+import re
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="ICMC-ASR enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing ICMC-ASR manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to an original ICMC-ASR recording), find the
+    enhanced recording corresponding to that supervision, and return the recording
+    together with a new supervision whose start and end times are adjusted to match
+    the enhanced recording.
+    """
+    file_name = Path(
+        f"{supervision.recording_id}-{supervision.speaker}-{round(100*supervision.start):06d}_{round(100*supervision.end):06d}.flac"
+    )
+    save_path = str(enhanced_dir / f"{supervision.recording_id}" / file_name)
+    # GSS fuses the channels, so the per-channel recording id DX0<n> (e.g. DX01)
+    # becomes DXmix in the enhanced output paths.
+    save_path = re.sub(r"DX0\d", "DXmix", save_path)
+    save_path = Path(save_path)
+    if save_path.exists():
+        recording = Recording.from_file(save_path)
+        if recording.duration == 0:
+            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+            return None
+
+        # The old supervision is relative to the original recording; create a new
+        # supervision relative to the enhanced segment.
+        new_supervision = fastcopy(
+            supervision,
+            recording_id=recording.id,
+            start=0,
+            duration=recording.duration,
+        )
+        return recording, new_supervision
+    else:
+        logging.warning(f"{save_path} does not exist.")
+        return None
+
+
+def main(args):
+    # Get arguments
+    manifests_dir = args.manifests_dir
+    enhanced_dir = args.enhanced_dir
+
+    # Load manifests from cache if they exist (saves time)
+    manifests = read_manifests_if_cached(
+        dataset_parts=["train", "dev"],
+        output_dir=manifests_dir,
+        prefix="icmcasr-sdm",
+        suffix="jsonl.gz",
+    )
+    if not manifests:
+        raise ValueError(
+            "ICMC-ASR SDM manifests not found in {}".format(manifests_dir)
+        )
+
+    with ThreadPoolExecutor(args.num_jobs) as ex:
+        for part in ["train", "dev"]:
+            logging.info(f"Processing {part}...")
+            supervisions_orig = manifests[part]["supervisions"].filter(
+                lambda s: s.duration >= args.min_segment_duration
+            )
+            futures = []
+
+            for supervision in tqdm(
+                supervisions_orig,
+                desc="Distributing tasks",
+            ):
+                futures.append(
+                    ex.submit(
+                        find_recording_and_create_new_supervision,
+                        enhanced_dir,
+                        supervision,
+                    )
+                )
+
+            recordings = []
+            supervisions = []
+            for future in tqdm(
+                futures,
+                total=len(futures),
+                desc="Processing tasks",
+            ):
+                result = future.result()
+                if result is not None:
+                    recording, new_supervision = result
+                    recordings.append(recording)
+                    supervisions.append(new_supervision)
+
+            # Remove duplicates from the recordings
+            recordings_nodup = {}
+            for recording in recordings:
+                if recording.id not in recordings_nodup:
+                    recordings_nodup[recording.id] = recording
+                else:
+                    logging.warning("Recording {} is duplicated.".format(recording.id))
+            recordings = RecordingSet.from_recordings(recordings_nodup.values())
+            supervisions = SupervisionSet.from_segments(supervisions)
+
+            recordings, supervisions = fix_manifests(
+                recordings=recordings, supervisions=supervisions
+            )
+
+            logging.info(f"Writing {part} enhanced manifests")
+            recordings.to_file(
+                manifests_dir / f"icmcasr-gss_recordings_{part}.jsonl.gz"
+            )
+            supervisions.to_file(
+                manifests_dir / f"icmcasr-gss_supervisions_{part}.jsonl.gz"
+            )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
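
The path logic in find_recording_and_create_new_supervision encodes GSS's output naming: one FLAC per supervision segment, named <recording>-<speaker>-<start>_<end>.flac with times in centiseconds, under a per-recording directory whose per-channel id DX0<n> is collapsed to DXmix. A self-contained sketch of that mapping, using the same f-string and regex as the script (the example ids are hypothetical, for illustration only):

import re
from pathlib import Path

def enhanced_segment_path(
    enhanced_dir: Path, recording_id: str, speaker: str, start: float, end: float
) -> Path:
    # Start/end are encoded in centiseconds, zero-padded to six digits,
    # matching the file_name construction in the script above.
    file_name = f"{recording_id}-{speaker}-{round(100 * start):06d}_{round(100 * end):06d}.flac"
    path = str(enhanced_dir / recording_id / file_name)
    # The per-channel recording id DX0<n> becomes DXmix after channel fusion.
    return Path(re.sub(r"DX0\d", "DXmix", path))

# Hypothetical ids:
print(enhanced_segment_path(Path("exp/icmc_gss/enhanced"), "ICMC_001_DX01", "S01", 3.2, 7.85))
# -> exp/icmc_gss/enhanced/ICMC_001_DXmix/ICMC_001_DXmix-S01-000320_000785.flac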
+ """ + file_name = Path( + f"{supervision.recording_id}-{supervision.speaker}-{round(100*supervision.start):06d}_{round(100*supervision.end):06d}.flac" + ) + save_path = str(enhanced_dir / f"{supervision.recording_id}" / file_name) + # replace re template DX0*C with DXmixC + save_path = re.sub(r"DX0(\d)", r"DXmix", save_path) + # convert it to Path object + save_path = Path(save_path) + if save_path.exists(): + recording = Recording.from_file(save_path) + if recording.duration == 0: + logging.warning(f"Skipping {save_path} which has duration 0 seconds.") + return None + + # Old supervision is wrt to the original recording, we create new supervision + # wrt to the enhanced segment + new_supervision = fastcopy( + supervision, + recording_id=recording.id, + start=0, + duration=recording.duration, + ) + return recording, new_supervision + else: + logging.warning(f"{save_path} does not exist.") + return None + + +def main(args): + # Get arguments + manifests_dir = args.manifests_dir + enhanced_dir = args.enhanced_dir + + # Load manifests from cache if they exist (saves time) + manifests = read_manifests_if_cached( + dataset_parts=["train", "dev"], + output_dir=manifests_dir, + prefix="icmcasr-sdm", + suffix="jsonl.gz", + ) + if not manifests: + raise ValueError( + "AliMeeting SDM manifests not found in {}".format(manifests_dir) + ) + + with ThreadPoolExecutor(args.num_jobs) as ex: + for part in ["train", "dev",]: + logging.info(f"Processing {part}...") + supervisions_orig = manifests[part]["supervisions"].filter( + lambda s: s.duration >= args.min_segment_duration + ) + futures = [] + + for supervision in tqdm( + supervisions_orig, + desc="Distributing tasks", + ): + futures.append( + ex.submit( + find_recording_and_create_new_supervision, + enhanced_dir, + supervision, + ) + ) + + recordings = [] + supervisions = [] + for future in tqdm( + futures, + total=len(futures), + desc="Processing tasks", + ): + result = future.result() + if result is not None: + recording, new_supervision = result + recordings.append(recording) + supervisions.append(new_supervision) + + # Remove duplicates from the recordings + recordings_nodup = {} + for recording in recordings: + if recording.id not in recordings_nodup: + recordings_nodup[recording.id] = recording + else: + logging.warning("Recording {} is duplicated.".format(recording.id)) + recordings = RecordingSet.from_recordings(recordings_nodup.values()) + supervisions = SupervisionSet.from_segments(supervisions) + + recordings, supervisions = fix_manifests( + recordings=recordings, supervisions=supervisions + ) + + logging.info(f"Writing {part} enhanced manifests") + recordings.to_file( + manifests_dir / f"icmcasr-gss_recordings_{part}.jsonl.gz" + ) + supervisions.to_file( + manifests_dir / f"icmcasr-gss_supervisions_{part}.jsonl.gz" + ) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/icmcasr/ASR/local/prepare_icmc_gss.sh b/egs/icmcasr/ASR/local/prepare_icmc_gss.sh new file mode 100644 index 000000000..b3490f9ab --- /dev/null +++ b/egs/icmcasr/ASR/local/prepare_icmc_gss.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# This script is used to run GSS-based enhancement on AMI data. +set -euo pipefail +nj=1 +stage=0 +stop_stage=1 + +. shared/parse_options.sh || exit 1 + +if [ $# != 2 ]; then + echo "Wrong #arguments ($#, expected 2)" + echo "Usage: local/prepare_icmc_gss.sh [options] " + echo "e.g. 
diff --git a/egs/icmcasr/ASR/prepare.sh b/egs/icmcasr/ASR/prepare.sh
index 0b8bbd676..9f1130a18 100755
--- a/egs/icmcasr/ASR/prepare.sh
+++ b/egs/icmcasr/ASR/prepare.sh
@@ -58,7 +58,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # to $dl_dir/icmcasr
   if [ ! -f data/manifests/.icmcasr_manifests.done ]; then
     mkdir -p data/manifests
-    for part in ihm sdm; do
+    for part in ihm sdm mdm; do
       lhotse prepare icmcasr --mic ${part} $dl_dir/ICMC-ASR data/manifests
     done
     touch data/manifests/.icmcasr_manifests.done
@@ -77,15 +77,24 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_icmc_gss.sh --stage 1 --stop_stage 6 data/manifests exp/icmc_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 3: Compute fbank for icmcasr"
+  log "Stage 4: Compute fbank for icmcasr"
   if [ ! -f data/fbank/.icmcasr.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_icmcasr.py --perturb-speed True
+    echo "Combining manifests"
+    lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+      gzip -c > data/manifests/cuts_train_all.jsonl.gz
     touch data/fbank/.icmcasr.done
   fi
 fi
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 4: Compute fbank for musan"
+  log "Stage 5: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
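
The `lhotse combine ... | shuf` pipeline above merges the IHM, reverberated-IHM, SDM, and GSS training cuts into a single shuffled manifest. A Python sketch of the same step (file names match the hunk; `ihm_rvb` denotes the reverberated IHM cuts):

from lhotse import CutSet

parts = ["ihm", "ihm_rvb", "sdm", "gss"]
cuts = None
for part in parts:
    part_cuts = CutSet.from_file(f"data/manifests/cuts_train_{part}.jsonl.gz")
    cuts = part_cuts if cuts is None else cuts + part_cuts

# to_eager() materializes the lazily-loaded cuts so that shuffle() can
# permute them, mimicking the `shuf` in the shell pipeline.
cuts.to_eager().shuffle().to_file("data/manifests/cuts_train_all.jsonl.gz")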
@@ -95,7 +104,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 lang_phone_dir=data/lang_phone
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 5: Prepare phone based lang"
+  log "Stage 6: Prepare phone based lang"
 
   mkdir -p $lang_phone_dir
@@ -111,7 +120,7 @@ fi
 fi
 
 lang_char_dir=data/lang_char
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 7: Prepare char based lang"
   mkdir -p $lang_char_dir
   # We reuse words.txt from phone based lexicon
@@ -142,7 +151,7 @@ fi
 fi
 fi
 
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-  log "Stage 7: Prepare Byte BPE based lang"
+  log "Stage 8: Prepare Byte BPE based lang"
 
   for vocab_size in ${vocab_sizes[@]}; do
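
Once the pipeline has run end to end, it is worth sanity-checking the combined training manifest before starting training. A quick check, not part of the recipe:

from lhotse import CutSet

cuts = CutSet.from_file("data/manifests/cuts_train_all.jsonl.gz")

# Prints cut/supervision counts and duration statistics for the set.
cuts.describe()

# GSS cuts were tagged with a "-gss" id suffix in compute_fbank_icmcasr.py,
# so we can confirm they made it into the mix.
n_gss = sum(1 for cut in cuts if "-gss" in cut.id)
print(f"{n_gss} GSS-enhanced cuts in the combined manifest")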