add gss support

commit 22f68dd344, parent e9de3fb289

egs/icmcasr/ASR/local/compute_fbank_icmcasr.py
@@ -64,12 +64,12 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
         suffix="jsonl.gz",
     )
     # For GSS we already have cuts so we read them directly.
-    # manifests_gss = read_manifests_if_cached(
-    #     dataset_parts=["train", "dev"],
-    #     output_dir=src_dir,
-    #     prefix="icmcasr-gss",
-    #     suffix="jsonl.gz",
-    # )
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "dev"],
+        output_dir=src_dir,
+        prefix="icmcasr-gss",
+        suffix="jsonl.gz",
+    )
 
     sampling_rate = 16000
 
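
Note: `read_manifests_if_cached` returns a dict keyed by dataset part ("train", "dev"), where each entry holds "recordings" and "supervisions" manifests. A minimal sketch of how the GSS manifests enabled above get consumed (paths assume this recipe's data/manifests layout):

    # Minimal sketch: load the GSS manifests produced by
    # local/prepare_icmc_enhanced.py and turn them into cuts.
    from pathlib import Path

    from lhotse import CutSet
    from lhotse.recipes.utils import read_manifests_if_cached

    src_dir = Path("data/manifests")
    manifests_gss = read_manifests_if_cached(
        dataset_parts=["train", "dev"],
        output_dir=src_dir,
        prefix="icmcasr-gss",
        suffix="jsonl.gz",
    )
    # Each part maps to {"recordings": ..., "supervisions": ...},
    # which CutSet.from_manifests unpacks directly.
    cuts_gss_train = CutSet.from_manifests(**manifests_gss["train"])
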
@@ -84,6 +84,10 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
     def _extract_feats(
         cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
     ) -> None:
+        # check if the features have already been computed
+        if storage_path.exists() or storage_path.with_suffix(".lca").exists():
+            logging.info(f"{storage_path} exists, skipping feature extraction")
+            return
         if speed_perturb:
             logging.info(f"Doing speed perturb")
             cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
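
Note: the new guard checks both the bare storage path and a `.lca` variant because `LilcomChunkyWriter` (used later in this script) writes its archive with that extension, so reruns skip splits whose features already exist. A standalone sketch of the same idempotency pattern:

    # Standalone sketch of the idempotency guard added to _extract_feats;
    # ".lca" is the archive extension written by lhotse's LilcomChunkyWriter.
    import logging
    from pathlib import Path

    def features_already_computed(storage_path: Path) -> bool:
        if storage_path.exists() or storage_path.with_suffix(".lca").exists():
            logging.info(f"{storage_path} exists, skipping feature extraction")
            return True
        return False
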
@@ -136,17 +140,17 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
     )
 
     logging.info("Processing train split GSS")
-    # cuts_gss = (
-    #     CutSet.from_manifests(**manifests_gss["train"])
-    #     .trim_to_supervisions(keep_overlapping=False)
-    #     .modify_ids(lambda x: x + "-gss")
-    # )
-    # _extract_feats(
-    #     cuts_gss,
-    #     output_dir / "feats_train_gss",
-    #     src_dir / "cuts_train_gss.jsonl.gz",
-    #     perturb_speed,
-    # )
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+        perturb_speed,
+    )
 
     logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
     for split in ["dev"]:
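
Note: `trim_to_supervisions(keep_overlapping=False)` yields one cut per supervision and drops overlapping speech, and `modify_ids` suffixes every cut ID with `-gss` so these cuts do not collide with the IHM/SDM cuts for the same segments when the train manifests are combined later (see the `lhotse combine` step in prepare.sh below). A tiny illustration with a hypothetical cut ID:

    # Tiny illustration of the ID rewrite done by .modify_ids above;
    # the cut ID is hypothetical.
    to_gss_id = lambda cut_id: cut_id + "-gss"
    assert to_gss_id("TRAIN_001-S01-000123_000456") == "TRAIN_001-S01-000123_000456-gss"
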
@@ -176,19 +180,19 @@ def compute_fbank_icmcasr(num_mel_bins: int = 80, perturb_speed: bool = False):
                 storage_type=LilcomChunkyWriter,
             )
         )
-        # logging.info(f"Processing {split} GSS")
-        # cuts_gss = (
-        #     CutSet.from_manifests(**manifests_gss[split])
-        #     .trim_to_supervisions(keep_overlapping=False)
-        #     .compute_and_store_features_batch(
-        #         extractor=extractor,
-        #         storage_path=output_dir / f"feats_{split}_gss",
-        #         manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
-        #         batch_duration=500,
-        #         num_workers=4,
-        #         storage_type=LilcomChunkyWriter,
-        #     )
-        # )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
 
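
Note: `compute_and_store_features_batch` runs the extractor on whole batches (here up to 500 s of audio per batch), which only pays off with a GPU-capable extractor. The `extractor` object is created outside this hunk; in icefall recipes it is typically a kaldifeat fbank, so the following construction is an assumption rather than something shown in this diff:

    # Assumed construction of the `extractor` used above (not shown in
    # this diff): a kaldifeat-based fbank extractor that can run on GPU.
    import torch
    from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig

    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
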
egs/icmcasr/ASR/local/prepare_icmc_enhanced.py (new file, 163 lines)
@@ -0,0 +1,163 @@
#!/usr/local/bin/python
# -*- coding: utf-8 -*-
# Data preparation for the ICMC-ASR GSS-enhanced dataset.

import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import re

from lhotse import Recording, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fastcopy
from tqdm import tqdm

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def get_args():
    import argparse

    parser = argparse.ArgumentParser(description="ICMC enhanced dataset preparation.")
    parser.add_argument(
        "manifests_dir",
        type=Path,
        help="Path to directory containing ICMC manifests.",
    )
    parser.add_argument(
        "enhanced_dir",
        type=Path,
        help="Path to enhanced data directory.",
    )
    parser.add_argument(
        "--num-jobs",
        "-j",
        type=int,
        default=1,
        help="Number of parallel jobs to run.",
    )
    parser.add_argument(
        "--min-segment-duration",
        "-d",
        type=float,
        default=0.0,
        help="Minimum duration of a segment in seconds.",
    )
    return parser.parse_args()


def find_recording_and_create_new_supervision(enhanced_dir, supervision):
    """
    Given a supervision (corresponding to an original ICMC-ASR recording), this
    function finds the enhanced recording corresponding to the supervision, and
    returns this recording and a new supervision whose start and end times are
    adjusted to match the enhanced recording.
    """
    file_name = Path(
        f"{supervision.recording_id}-{supervision.speaker}-{round(100*supervision.start):06d}_{round(100*supervision.end):06d}.flac"
    )
    save_path = str(enhanced_dir / f"{supervision.recording_id}" / file_name)
    # Replace the recording-id template DX0<n>C with DXmixC.
    save_path = re.sub(r"DX0(\d)", r"DXmix", save_path)
    # Convert it back to a Path object.
    save_path = Path(save_path)
    if save_path.exists():
        recording = Recording.from_file(save_path)
        if recording.duration == 0:
            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
            return None

        # The old supervision is w.r.t. the original recording; we create a new
        # supervision w.r.t. the enhanced segment.
        new_supervision = fastcopy(
            supervision,
            recording_id=recording.id,
            start=0,
            duration=recording.duration,
        )
        return recording, new_supervision
    else:
        logging.warning(f"{save_path} does not exist.")
        return None


def main(args):
    # Get arguments
    manifests_dir = args.manifests_dir
    enhanced_dir = args.enhanced_dir

    # Load manifests from cache if they exist (saves time)
    manifests = read_manifests_if_cached(
        dataset_parts=["train", "dev"],
        output_dir=manifests_dir,
        prefix="icmcasr-sdm",
        suffix="jsonl.gz",
    )
    if not manifests:
        raise ValueError(
            "ICMC-ASR SDM manifests not found in {}".format(manifests_dir)
        )

    with ThreadPoolExecutor(args.num_jobs) as ex:
        for part in ["train", "dev"]:
            logging.info(f"Processing {part}...")
            supervisions_orig = manifests[part]["supervisions"].filter(
                lambda s: s.duration >= args.min_segment_duration
            )
            futures = []

            for supervision in tqdm(
                supervisions_orig,
                desc="Distributing tasks",
            ):
                futures.append(
                    ex.submit(
                        find_recording_and_create_new_supervision,
                        enhanced_dir,
                        supervision,
                    )
                )

            recordings = []
            supervisions = []
            for future in tqdm(
                futures,
                total=len(futures),
                desc="Processing tasks",
            ):
                result = future.result()
                if result is not None:
                    recording, new_supervision = result
                    recordings.append(recording)
                    supervisions.append(new_supervision)

            # Remove duplicates from the recordings
            recordings_nodup = {}
            for recording in recordings:
                if recording.id not in recordings_nodup:
                    recordings_nodup[recording.id] = recording
                else:
                    logging.warning("Recording {} is duplicated.".format(recording.id))
            recordings = RecordingSet.from_recordings(recordings_nodup.values())
            supervisions = SupervisionSet.from_segments(supervisions)

            recordings, supervisions = fix_manifests(
                recordings=recordings, supervisions=supervisions
            )

            logging.info(f"Writing {part} enhanced manifests")
            recordings.to_file(
                manifests_dir / f"icmcasr-gss_recordings_{part}.jsonl.gz"
            )
            supervisions.to_file(
                manifests_dir / f"icmcasr-gss_supervisions_{part}.jsonl.gz"
            )


if __name__ == "__main__":
    args = get_args()
    main(args)
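
To make the path construction in `find_recording_and_create_new_supervision` concrete, here is a worked example with a hypothetical recording ID and speaker. The `re.sub` collapses the per-channel `DX0<n>` part of the ID to `DXmix`, which presumably matches the directory layout GSS produces (one enhanced stream per session rather than one per channel):

    # Worked example of the enhanced-file path mapping; recording_id and
    # speaker are hypothetical.
    import re
    from pathlib import Path

    recording_id = "TRAIN_001_DX01C"  # hypothetical
    speaker = "S01"                   # hypothetical
    start, end = 12.34, 15.67         # supervision boundaries in seconds

    file_name = f"{recording_id}-{speaker}-{round(100*start):06d}_{round(100*end):06d}.flac"
    save_path = str(Path("exp/icmc_gss/enhanced") / recording_id / file_name)
    save_path = re.sub(r"DX0(\d)", r"DXmix", save_path)
    print(save_path)
    # exp/icmc_gss/enhanced/TRAIN_001_DXmixC/TRAIN_001_DXmixC-S01-001234_001567.flac
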
egs/icmcasr/ASR/local/prepare_icmc_gss.sh (new file, 99 lines)
@@ -0,0 +1,99 @@
#!/bin/bash
# This script is used to run GSS-based enhancement on ICMC-ASR data.
set -euo pipefail
nj=1
stage=0
stop_stage=1

. shared/parse_options.sh || exit 1

if [ $# != 2 ]; then
  echo "Wrong #arguments ($#, expected 2)"
  echo "Usage: local/prepare_icmc_gss.sh [options] <data-dir> <exp-dir>"
  echo "e.g. local/prepare_icmc_gss.sh data/manifests exp/icmc_gss"
  echo "main options (for others, see top of script file)"
  echo "  --nj <nj>                  # number of parallel jobs"
  echo "  --stage <stage>            # stage to start running from"
  exit 1;
fi

DATA_DIR=$1
EXP_DIR=$2

mkdir -p $EXP_DIR

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare cut sets"
  for part in train dev; do
    lhotse cut simple \
      -r $DATA_DIR/icmcasr-mdm_recordings_${part}.jsonl.gz \
      -s $DATA_DIR/icmcasr-mdm_supervisions_${part}.jsonl.gz \
      $EXP_DIR/cuts_${part}.jsonl.gz
  done
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
  for part in train dev; do
    lhotse cut trim-to-supervisions --discard-overlapping \
      $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Split manifests for multi-GPU processing (optional)"
  for part in train dev; do
    gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
      $EXP_DIR/cuts_per_segment_${part}_split$nj
  done
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Enhance train segments using GSS (requires GPU)"
  # For train, we use a smaller context and larger batches to speed up processing.
  for JOB in $(seq $nj); do
    gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
      --bss-iterations 10 \
      --context-duration 5.0 \
      --use-garbage-class \
      --channels 0,1,2,3 \
      --min-segment-length 0.05 \
      --max-segment-length 25.0 \
      --max-batch-duration 60.0 \
      --num-buckets 4 \
      --num-workers 4
  done
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Enhance eval/test segments using GSS (requires GPU)"
  # For eval/test, we use a larger context and smaller batches to get better quality.
  for part in dev; do
    for JOB in $(seq $nj); do
      gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
        $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
        $EXP_DIR/enhanced \
        --bss-iterations 10 \
        --context-duration 15.0 \
        --use-garbage-class \
        --min-segment-length 0.05 \
        --max-segment-length 16.0 \
        --max-batch-duration 45.0 \
        --num-buckets 4 \
        --num-workers 4
    done
  done
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare manifests for GSS-enhanced data"
  python3 local/prepare_icmc_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
fi
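
For reference, Stage 1's `lhotse cut simple` is roughly equivalent to the following Python (same MDM manifests; eager loading assumed):

    # Rough Python equivalent of Stage 1's `lhotse cut simple` call:
    # build a cut set from the MDM recordings and supervisions manifests.
    from lhotse import CutSet, RecordingSet, SupervisionSet

    part = "train"
    recordings = RecordingSet.from_file(
        f"data/manifests/icmcasr-mdm_recordings_{part}.jsonl.gz"
    )
    supervisions = SupervisionSet.from_file(
        f"data/manifests/icmcasr-mdm_supervisions_{part}.jsonl.gz"
    )
    cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
    cuts.to_file(f"exp/icmc_gss/cuts_{part}.jsonl.gz")

egs/icmcasr/ASR/prepare.sh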
@@ -58,7 +58,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # to $dl_dir/icmcasr
   if [ ! -f data/manifests/.icmcasr_manifests.done ]; then
     mkdir -p data/manifests
-    for part in ihm sdm; do
+    for part in ihm sdm mdm; do
       lhotse prepare icmcasr --mic ${part} $dl_dir/ICMC-ASR data/manifests
     done
     touch data/manifests/.icmcasr_manifests.done
@@ -77,15 +77,24 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_icmc_gss.sh --stage 1 --stop_stage 6 data/manifests exp/icmc_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 3: Compute fbank for icmcasr"
+  log "Stage 4: Compute fbank for icmcasr"
   if [ ! -f data/fbank/.icmcasr.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_icmcasr.py --perturb-speed True
+    echo "Combining manifests"
+    lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+      gzip -c > data/manifests/cuts_train_all.jsonl.gz
     touch data/fbank/.icmcasr.done
   fi
 fi
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute fbank for musan"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
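
Note: the new `lhotse combine ... | shuf | gzip` pipeline merges the IHM, reverberated-IHM, SDM, and GSS train cuts into a single shuffled training manifest. A rough Python equivalent (eager loading assumed; paths follow the recipe's layout):

    # Rough Python equivalent of the `lhotse combine ... | shuf | gzip`
    # step above.
    from lhotse import CutSet, combine

    parts = ["ihm", "ihm_rvb", "sdm", "gss"]
    cut_sets = [
        CutSet.from_file(f"data/manifests/cuts_train_{p}.jsonl.gz") for p in parts
    ]
    cuts_all = combine(*cut_sets).to_eager().shuffle()
    cuts_all.to_file("data/manifests/cuts_train_all.jsonl.gz")
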
@@ -95,7 +104,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 lang_phone_dir=data/lang_phone
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare phone based lang"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare phone based lang"
   mkdir -p $lang_phone_dir
 
@@ -111,7 +120,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi
 
 lang_char_dir=data/lang_char
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare char based lang"
   mkdir -p $lang_char_dir
   # We reuse words.txt from phone based lexicon
@@ -142,7 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   fi
 fi
 
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-  log "Stage 7: Prepare Byte BPE based lang"
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare Byte BPE based lang"
 
   for vocab_size in ${vocab_sizes[@]}; do