Add GSS support

Yuekai Zhang 2023-11-27 14:51:30 +08:00
parent e9de3fb289
commit 22f68dd344
4 changed files with 310 additions and 35 deletions

local/compute_fbank_icmcasr.py

@@ -64,12 +64,12 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
suffix="jsonl.gz",
)
# The GSS manifests were written by local/prepare_icmc_enhanced.py, so we read them directly.
# manifests_gss = read_manifests_if_cached(
# dataset_parts=["train", "dev"],
# output_dir=src_dir,
# prefix="icmcasr-gss",
# suffix="jsonl.gz",
# )
manifests_gss = read_manifests_if_cached(
dataset_parts=["train", "dev"],
output_dir=src_dir,
prefix="icmcasr-gss",
suffix="jsonl.gz",
)
sampling_rate = 16000
@@ -84,6 +84,10 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
def _extract_feats(
cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
) -> None:
# check if the features have already been computed
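# (with LilcomChunkyWriter, the archive is typically written as "<storage_path>.lca", hence the second check)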
if storage_path.exists() or storage_path.with_suffix(".lca").exists():
logging.info(f"{storage_path} exists, skipping feature extraction")
return
if speed_perturb:
logging.info(f"Doing speed perturb")
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
@@ -136,17 +140,17 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
)
logging.info("Processing train split GSS")
# cuts_gss = (
# CutSet.from_manifests(**manifests_gss["train"])
# .trim_to_supervisions(keep_overlapping=False)
# .modify_ids(lambda x: x + "-gss")
# )
# _extract_feats(
# cuts_gss,
# output_dir / "feats_train_gss",
# src_dir / "cuts_train_gss.jsonl.gz",
# perturb_speed,
# )
cuts_gss = (
CutSet.from_manifests(**manifests_gss["train"])
.trim_to_supervisions(keep_overlapping=False)
.modify_ids(lambda x: x + "-gss")
)
_extract_feats(
cuts_gss,
output_dir / "feats_train_gss",
src_dir / "cuts_train_gss.jsonl.gz",
perturb_speed,
)
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
for split in ["dev"]:
@@ -176,19 +180,19 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
storage_type=LilcomChunkyWriter,
)
)
# logging.info(f"Processing {split} GSS")
# cuts_gss = (
# CutSet.from_manifests(**manifests_gss[split])
# .trim_to_supervisions(keep_overlapping=False)
# .compute_and_store_features_batch(
# extractor=extractor,
# storage_path=output_dir / f"feats_{split}_gss",
# manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
# batch_duration=500,
# num_workers=4,
# storage_type=LilcomChunkyWriter,
# )
# )
logging.info(f"Processing {split} GSS")
cuts_gss = (
CutSet.from_manifests(**manifests_gss[split])
.trim_to_supervisions(keep_overlapping=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_gss",
manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)

local/prepare_icmc_enhanced.py

@@ -0,0 +1,163 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Data preparation for the ICMC-ASR GSS-enhanced dataset.
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import re
from lhotse import Recording, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fastcopy
from tqdm import tqdm
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
def get_args():
import argparse
parser = argparse.ArgumentParser(description="ICMC enhanced dataset preparation.")
parser.add_argument(
"manifests_dir",
type=Path,
help="Path to directory containing ICMC manifests.",
)
parser.add_argument(
"enhanced_dir",
type=Path,
help="Path to enhanced data directory.",
)
parser.add_argument(
"--num-jobs",
"-j",
type=int,
default=1,
help="Number of parallel jobs to run.",
)
parser.add_argument(
"--min-segment-duration",
"-d",
type=float,
default=0.0,
help="Minimum duration of a segment in seconds.",
)
return parser.parse_args()
def find_recording_and_create_new_supervision(enhanced_dir, supervision):
"""
Given a supervision (corresponding to an original ICMC-ASR recording), this function
finds the enhanced recording corresponding to the supervision, and returns this
recording together with a new supervision whose start and end times are adjusted
to match the enhanced recording.
"""
file_name = Path(
f"{supervision.recording_id}-{supervision.speaker}-{round(100*supervision.start):06d}_{round(100*supervision.end):06d}.flac"
)
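# e.g. a segment spanning 12.34s-56.78s yields "<recording_id>-<speaker>-001234_005678.flac"
# (start/end encoded in centiseconds, zero-padded to six digits)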
save_path = str(enhanced_dir / f"{supervision.recording_id}" / file_name)
# Rewrite the per-channel device ID DX0<n> to DXmix, since GSS produces a single channel-merged stream
save_path = re.sub(r"DX0(\d)", r"DXmix", save_path)
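# e.g. an id like "0012_DX01C" (illustrative) becomes "0012_DXmixC"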
# convert it to Path object
save_path = Path(save_path)
if save_path.exists():
recording = Recording.from_file(save_path)
if recording.duration == 0:
logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
return None
# Old supervision is wrt to the original recording, we create new supervision
# wrt to the enhanced segment
new_supervision = fastcopy(
supervision,
recording_id=recording.id,
start=0,
duration=recording.duration,
)
return recording, new_supervision
else:
logging.warning(f"{save_path} does not exist.")
return None
def main(args):
# Get arguments
manifests_dir = args.manifests_dir
enhanced_dir = args.enhanced_dir
# Load manifests from cache if they exist (saves time)
manifests = read_manifests_if_cached(
dataset_parts=["train", "dev"],
output_dir=manifests_dir,
prefix="icmcasr-sdm",
suffix="jsonl.gz",
)
if not manifests:
raise ValueError(
"AliMeeting SDM manifests not found in {}".format(manifests_dir)
)
with ThreadPoolExecutor(args.num_jobs) as ex:
for part in ["train", "dev"]:
logging.info(f"Processing {part}...")
supervisions_orig = manifests[part]["supervisions"].filter(
lambda s: s.duration >= args.min_segment_duration
)
futures = []
for supervision in tqdm(
supervisions_orig,
desc="Distributing tasks",
):
futures.append(
ex.submit(
find_recording_and_create_new_supervision,
enhanced_dir,
supervision,
)
)
recordings = []
supervisions = []
for future in tqdm(
futures,
total=len(futures),
desc="Processing tasks",
):
result = future.result()
if result is not None:
recording, new_supervision = result
recordings.append(recording)
supervisions.append(new_supervision)
# Remove duplicates from the recordings
recordings_nodup = {}
for recording in recordings:
if recording.id not in recordings_nodup:
recordings_nodup[recording.id] = recording
else:
logging.warning("Recording {} is duplicated.".format(recording.id))
recordings = RecordingSet.from_recordings(recordings_nodup.values())
supervisions = SupervisionSet.from_segments(supervisions)
recordings, supervisions = fix_manifests(
recordings=recordings, supervisions=supervisions
)
logging.info(f"Writing {part} enhanced manifests")
recordings.to_file(
manifests_dir / f"icmcasr-gss_recordings_{part}.jsonl.gz"
)
supervisions.to_file(
manifests_dir / f"icmcasr-gss_supervisions_{part}.jsonl.gz"
)
if __name__ == "__main__":
args = get_args()
main(args)

local/prepare_icmc_gss.sh

@@ -0,0 +1,99 @@
#!/bin/bash
# This script is used to run GSS-based enhancement on ICMC-ASR data.
set -euo pipefail
nj=1
stage=0
stop_stage=1
. shared/parse_options.sh || exit 1
if [ $# != 2 ]; then
echo "Wrong #arguments ($#, expected 2)"
echo "Usage: local/prepare_icmc_gss.sh [options] <data-dir> <exp-dir>"
echo "e.g. local/prepare_icmc_gss.sh data/manifests exp/ami_gss"
echo "main options (for others, see top of script file)"
echo " --nj <nj> # number of parallel jobs"
echo " --stage <stage> # stage to start running from"
exit 1;
fi
DATA_DIR=$1
EXP_DIR=$2
mkdir -p $EXP_DIR
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare cut sets"
for part in train dev; do
lhotse cut simple \
-r $DATA_DIR/icmcasr-mdm_recordings_${part}.jsonl.gz \
-s $DATA_DIR/icmcasr-mdm_supervisions_${part}.jsonl.gz \
$EXP_DIR/cuts_${part}.jsonl.gz
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
for part in train dev; do
lhotse cut trim-to-supervisions --discard-overlapping \
$EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split manifests for multi-GPU processing (optional)"
for part in train dev; do
gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj
done
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Enhance train segments using GSS (requires GPU)"
# for train, we use smaller context and larger batches to speed up processing
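# --channels 0,1,2,3 selects all four far-field channels (assumption: 4-channel DXmix input)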
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 5.0 \
--use-garbage-class \
--channels 0,1,2,3 \
--min-segment-length 0.05 \
--max-segment-length 25.0 \
--max-batch-duration 60.0 \
--num-buckets 4 \
--num-workers 4
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Enhance eval/test segments using GSS (using GPU)"
# for eval/test, we use larger context and smaller batches to get better quality
for part in dev; do
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
$EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 15.0 \
--use-garbage-class \
--min-segment-length 0.05 \
--max-segment-length 16.0 \
--max-batch-duration 45.0 \
--num-buckets 4 \
--num-workers 4
done
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare manifests for GSS-enhanced data"
python3 local/prepare_icmc_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
fi

prepare.sh

@@ -58,7 +58,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# to $dl_dir/icmcasr
if [ ! -f data/manifests/.icmcasr_manifests.done ]; then
mkdir -p data/manifests
for part in ihm sdm; do
for part in ihm sdm mdm; do
lhotse prepare icmcasr --mic ${part} $dl_dir/ICMC-ASR data/manifests
done
touch data/manifests/.icmcasr_manifests.done
@@ -77,15 +77,24 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
# We assume that you have installed the GSS package: https://github.com/desh2608/gss
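# e.g. (assumption; see the GSS README for the exact command): pip install git+https://github.com/desh2608/gss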
local/prepare_icmc_gss.sh --stage 1 --stop_stage 6 data/manifests exp/icmc_gss
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 3: Compute fbank for icmcasr"
if [ ! -f data/fbank/.icmcasr.done ]; then
mkdir -p data/fbank
./local/compute_fbank_icmcasr.py --perturb-speed True
echo "Combining manifests"
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
gzip -c > data/manifests/cuts_train_all.jsonl.gz
touch data/fbank/.icmcasr.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
mkdir -p data/fbank
@@ -95,7 +104,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
lang_phone_dir=data/lang_phone
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p $lang_phone_dir
@@ -111,7 +120,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 6: Prepare char based lang"
mkdir -p $lang_char_dir
# We reuse words.txt from phone based lexicon
@@ -142,7 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 7: Prepare Byte BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do