diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py index 8d4f4ec98..1c5b3457c 100644 --- a/egs/audioset/AT/local/generate_audioset_manifest.py +++ b/egs/audioset/AT/local/generate_audioset_manifest.py @@ -25,6 +25,7 @@ import csv import glob import logging import os +from typing import Dict import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter @@ -38,21 +39,38 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def parse_csv(csv_file): +def get_ID_mapping(csv_file): + # get a mapping between class ID and class name + mapping = {} + with open(csv_file, "r") as fin: + reader = csv.reader(fin, delimiter=",") + for i, row in enumerate(reader): + if i == 0: + continue + mapping[row[1]] = row[0] + return mapping + + +def parse_csv(csv_file: str, id_mapping: Dict): # The content of the csv file shoud be something like this # ------------------------------------------------------ # filename label # dataset/AudioSet/balanced/xxxx.wav 0;451 # dataset/AudioSet/balanced/xxxy.wav 375 # ------------------------------------------------------ + + def name2id(names): + ids = [id_mapping[name] for name in names.split(",")] + return ";".join(ids) + mapping = {} with open(csv_file, "r") as fin: - reader = csv.reader(fin, delimiter="\t") + reader = csv.reader(fin, delimiter=" ") for i, row in enumerate(reader): - if i == 0: + if i <= 2: continue - key = "/".join(row[0].split("/")[-2:]) - mapping[key] = row[1] + key = row[0].replace(",", "") + mapping[key] = name2id(row[-1]) return mapping @@ -67,7 +85,7 @@ def get_parser(): "--split", type=str, default="balanced", - choices=["balanced", "unbalanced", "eval", "eval_all"], + choices=["balanced", "unbalanced", "eval"], ) parser.add_argument( @@ -91,21 +109,21 @@ def main(): num_mel_bins = 80 if split in ["balanced", "unbalanced"]: - csv_file = "downloads/audioset/full_train_asedata_with_duration.csv" + csv_file = f"{dataset_dir}/{split}_train_segments.csv" elif split == "eval": - csv_file = "downloads/audioset/eval.csv" - elif split == "eval_all": - csv_file = "downloads/audioset/eval_all.csv" + csv_file = f"{dataset_dir}/eval_segments.csv" else: raise ValueError() - labels = parse_csv(csv_file) + class_indices_csv = f"{dataset_dir}/class_labels_indices.csv" + id_mapping = get_ID_mapping(class_indices_csv) + labels = parse_csv(csv_file, id_mapping) - audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav") + audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav") new_cuts = [] for i, audio in enumerate(audio_files): - cut_id = "/".join(audio.split("/")[-2:]) + cut_id = audio.split("/")[-1].split("_")[0] recording = Recording.from_file(audio, cut_id) cut = MonoCut( id=cut_id, @@ -140,7 +158,7 @@ def main(): with get_executor() as ex: cuts = cuts.compute_and_store_features( extractor=extractor, - storage_path=f"{feat_output_dir}/{split}_{split}_feats", + storage_path=f"{feat_output_dir}/{split}_feats", num_jobs=num_jobs if ex is None else 80, executor=ex, storage_type=LilcomChunkyWriter,