update the script to generate audioset manifest

marcoyang 2024-04-08 18:46:09 +08:00
parent 01b744f127
commit 25d22d9318


@@ -25,6 +25,7 @@ import csv
 import glob
 import logging
 import os
+from typing import Dict
 
 import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
@@ -38,21 +39,38 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def parse_csv(csv_file):
+def get_ID_mapping(csv_file):
+    # get a mapping between class ID and class name
+    mapping = {}
+    with open(csv_file, "r") as fin:
+        reader = csv.reader(fin, delimiter=",")
+        for i, row in enumerate(reader):
+            if i == 0:
+                continue
+            mapping[row[1]] = row[0]
+    return mapping
+
+
+def parse_csv(csv_file: str, id_mapping: Dict):
     # The content of the csv file should be something like this
     # ------------------------------------------------------
     # filename  label
     # dataset/AudioSet/balanced/xxxx.wav  0;451
     # dataset/AudioSet/balanced/xxxy.wav  375
     # ------------------------------------------------------
+    def name2id(names):
+        ids = [id_mapping[name] for name in names.split(",")]
+        return ";".join(ids)
+
     mapping = {}
     with open(csv_file, "r") as fin:
-        reader = csv.reader(fin, delimiter="\t")
+        reader = csv.reader(fin, delimiter=" ")
         for i, row in enumerate(reader):
-            if i == 0:
+            if i <= 2:
                 continue
-            key = "/".join(row[0].split("/")[-2:])
-            mapping[key] = row[1]
+            key = row[0].replace(",", "")
+            mapping[key] = name2id(row[-1])
     return mapping
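
For reference (not part of the commit): a minimal, self-contained sketch of what the two helpers above compute, assuming the standard AudioSet metadata layout (class_labels_indices.csv with index,mid,display_name columns; a segments CSV with three leading comment lines, then a YTID field and a quoted, comma-separated label field). The sample rows below are made up for illustration.

import csv
import io

# class_labels_indices.csv style: index,mid,display_name (sample rows for illustration)
class_csv = io.StringIO(
    "index,mid,display_name\n"
    '0,/m/09x0r,"Speech"\n'
    '1,/m/05zppz,"Male speech, man speaking"\n'
)
id_mapping = {}
for i, row in enumerate(csv.reader(class_csv, delimiter=",")):
    if i == 0:  # skip the header row
        continue
    id_mapping[row[1]] = row[0]  # e.g. "/m/09x0r" -> "0"

# segments CSV style: three comment lines, then YTID, start, end, quoted labels
segments_csv = io.StringIO(
    "# header line 1 (creation info)\n"
    "# header line 2 (counts)\n"
    "# YTID, start_seconds, end_seconds, positive_labels\n"
    'abc123XYZ00, 0.000, 10.000, "/m/09x0r,/m/05zppz"\n'
)
labels = {}
for i, row in enumerate(csv.reader(segments_csv, delimiter=" ")):
    if i <= 2:  # skip the three comment lines
        continue
    key = row[0].replace(",", "")  # "abc123XYZ00" (trailing comma stripped)
    labels[key] = ";".join(id_mapping[name] for name in row[-1].split(","))

print(labels)  # {'abc123XYZ00': '0;1'}

The semicolon-joined numeric labels match the "0;451"-style format shown in the comment block of parse_csv.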
@@ -67,7 +85,7 @@ def get_parser():
         "--split",
         type=str,
         default="balanced",
-        choices=["balanced", "unbalanced", "eval", "eval_all"],
+        choices=["balanced", "unbalanced", "eval"],
     )
 
     parser.add_argument(
@@ -91,21 +109,21 @@ def main():
     num_mel_bins = 80
 
     if split in ["balanced", "unbalanced"]:
-        csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
+        csv_file = f"{dataset_dir}/{split}_train_segments.csv"
     elif split == "eval":
-        csv_file = "downloads/audioset/eval.csv"
-    elif split == "eval_all":
-        csv_file = "downloads/audioset/eval_all.csv"
+        csv_file = f"{dataset_dir}/eval_segments.csv"
     else:
         raise ValueError()
-    labels = parse_csv(csv_file)
+    class_indices_csv = f"{dataset_dir}/class_labels_indices.csv"
+    id_mapping = get_ID_mapping(class_indices_csv)
+    labels = parse_csv(csv_file, id_mapping)
 
-    audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav")
+    audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav")
 
     new_cuts = []
     for i, audio in enumerate(audio_files):
-        cut_id = "/".join(audio.split("/")[-2:])
+        cut_id = audio.split("/")[-1].split("_")[0]
         recording = Recording.from_file(audio, cut_id)
         cut = MonoCut(
             id=cut_id,
@@ -140,7 +158,7 @@ def main():
     with get_executor() as ex:
         cuts = cuts.compute_and_store_features(
             extractor=extractor,
-            storage_path=f"{feat_output_dir}/{split}_{split}_feats",
+            storage_path=f"{feat_output_dir}/{split}_feats",
             num_jobs=num_jobs if ex is None else 80,
             executor=ex,
             storage_type=LilcomChunkyWriter,
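
As a usage note (also not part of the commit): once the script has written the cuts, the resulting manifest can be sanity-checked directly with lhotse. The path below is a placeholder; point it at whatever manifest file the script actually saves.

from lhotse import CutSet

# Placeholder path: substitute the manifest produced by the script.
cuts = CutSet.from_file("data/fbank/cuts_audioset_balanced.jsonl.gz")
for cut in cuts:
    print(cut.id, round(cut.duration, 2))  # one AudioSet clip per cut
    break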