icefall/egs/audioset/AT/local/generate_audioset_manifest.py
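"""
Generate the Lhotse cut manifest and compute fbank features for AudioSet.

Usage sketch (arguments and defaults follow get_parser() below; adapt the
paths to your own AudioSet download):

    python generate_audioset_manifest.py \
        --dataset-dir downloads/audioset \
        --split balanced \
        --feat-output-dir data/fbank_audioset

The fbank features are stored under --feat-output-dir and the manifest is
written to <feat-output-dir>/cuts_audioset_<split>.jsonl.gz.
"""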

import argparse
import csv
import glob
import logging

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.audio import Recording
from lhotse.cut import MonoCut
from lhotse.supervision import SupervisionSegment

from icefall.utils import get_executor

torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"):
    # Map each audio key (the last two components of the path in the first
    # column) to the label string in the second column. The CSV is
    # tab-separated and its first row is treated as a header.
    mapping = {}
    with open(csv_file, "r") as fin:
        reader = csv.reader(fin, delimiter="\t")
        for i, row in enumerate(reader):
            if i == 0:
                # Skip the header row.
                continue
            key = "/".join(row[0].split("/")[-2:])
            mapping[key] = row[1]
    return mapping


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--dataset-dir", type=str, default="downloads/audioset")

    parser.add_argument(
        "--split",
        type=str,
        default="balanced",
        choices=["balanced", "unbalanced", "eval", "eval_all"],
    )

    parser.add_argument(
        "--feat-output-dir",
        type=str,
        default="data/fbank_audioset",
    )

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    split = args.split
    feat_output_dir = args.feat_output_dir

    num_jobs = 15
    num_mel_bins = 80

    # Select the label CSV corresponding to the requested split.
    if split in ["balanced", "unbalanced"]:
        csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
    elif split == "eval":
        csv_file = "downloads/audioset/eval.csv"
    elif split == "eval_all":
        csv_file = "downloads/audioset/eval_all.csv"
    else:
        raise ValueError(f"Unsupported split: {split}")

    labels = parse_csv(csv_file)

    # All wav files are expected under <dataset-dir>/eval/wav_all.
    audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav")

    new_cuts = []
    for i, audio in enumerate(audio_files):
        # Use the last two path components as the cut id so it can be looked
        # up in the label mapping built by parse_csv().
        cut_id = "/".join(audio.split("/")[-2:])
        recording = Recording.from_file(audio, cut_id)
        cut = MonoCut(
            id=cut_id,
            start=0.0,
            duration=recording.duration,
            channel=0,
            recording=recording,
        )
        supervision = SupervisionSegment(
            id=cut_id,
            recording_id=cut.recording.id,
            start=0.0,
            channel=0,
            duration=cut.duration,
        )
        try:
            supervision.audio_event = labels[cut_id]
        except KeyError:
            logging.info(f"No labels found for {cut_id}.")
            supervision.audio_event = ""
        cut.supervisions = [supervision]
        new_cuts.append(cut)

        if i % 100 == 0 and i:
            logging.info(f"Processed {i} cuts so far.")

    cuts = CutSet.from_cuts(new_cuts)

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    logging.info(f"Computing fbank features for {split}")
    with get_executor() as ex:
        cuts = cuts.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{feat_output_dir}/{split}_feats",
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )

    manifest_output_path = f"{feat_output_dir}/cuts_audioset_{split}.jsonl.gz"
    logging.info(f"Storing the manifest to {manifest_output_path}")
    cuts.to_jsonl(manifest_output_path)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()