Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)
update the script to generate audioset manifest
This commit is contained in:
parent 01b744f127
commit 25d22d9318
@@ -25,6 +25,7 @@ import csv
 import glob
 import logging
 import os
+from typing import Dict
 
 import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
@@ -38,21 +39,38 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def parse_csv(csv_file):
+def get_ID_mapping(csv_file):
+    # get a mapping between class ID and class name
+    mapping = {}
+    with open(csv_file, "r") as fin:
+        reader = csv.reader(fin, delimiter=",")
+        for i, row in enumerate(reader):
+            if i == 0:
+                continue
+            mapping[row[1]] = row[0]
+    return mapping
+
+
+def parse_csv(csv_file: str, id_mapping: Dict):
     # The content of the csv file shoud be something like this
     # ------------------------------------------------------
     # filename label
     # dataset/AudioSet/balanced/xxxx.wav 0;451
     # dataset/AudioSet/balanced/xxxy.wav 375
     # ------------------------------------------------------
+
+    def name2id(names):
+        ids = [id_mapping[name] for name in names.split(",")]
+        return ";".join(ids)
+
     mapping = {}
     with open(csv_file, "r") as fin:
-        reader = csv.reader(fin, delimiter="\t")
+        reader = csv.reader(fin, delimiter=" ")
         for i, row in enumerate(reader):
-            if i == 0:
+            if i <= 2:
                 continue
-            key = "/".join(row[0].split("/")[-2:])
-            mapping[key] = row[1]
+            key = row[0].replace(",", "")
+            mapping[key] = name2id(row[-1])
     return mapping
 
 
||||
@@ -67,7 +85,7 @@ def get_parser():
         "--split",
         type=str,
         default="balanced",
-        choices=["balanced", "unbalanced", "eval", "eval_all"],
+        choices=["balanced", "unbalanced", "eval"],
     )
 
     parser.add_argument(
||||
@@ -91,21 +109,21 @@ def main():
     num_mel_bins = 80
 
     if split in ["balanced", "unbalanced"]:
-        csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
+        csv_file = f"{dataset_dir}/{split}_train_segments.csv"
     elif split == "eval":
-        csv_file = "downloads/audioset/eval.csv"
-    elif split == "eval_all":
-        csv_file = "downloads/audioset/eval_all.csv"
+        csv_file = f"{dataset_dir}/eval_segments.csv"
     else:
         raise ValueError()
 
-    labels = parse_csv(csv_file)
+    class_indices_csv = f"{dataset_dir}/class_labels_indices.csv"
+    id_mapping = get_ID_mapping(class_indices_csv)
+    labels = parse_csv(csv_file, id_mapping)
 
-    audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav")
+    audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav")
 
     new_cuts = []
     for i, audio in enumerate(audio_files):
-        cut_id = "/".join(audio.split("/")[-2:])
+        cut_id = audio.split("/")[-1].split("_")[0]
         recording = Recording.from_file(audio, cut_id)
         cut = MonoCut(
             id=cut_id,
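The hunk above only shows the first fields of the MonoCut; a hedged sketch of how a whole-file cut for one AudioSet clip could be assembled with lhotse (everything beyond id and recording is an assumption, not read from this diff):

# Sketch, not the commit's exact code: wrap each wav in a whole-file MonoCut.
recording = Recording.from_file(audio, cut_id)
cut = MonoCut(
    id=cut_id,
    start=0.0,
    duration=recording.duration,
    channel=0,
    recording=recording,
)
new_cuts.append(cut)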
@@ -140,7 +158,7 @@ def main():
     with get_executor() as ex:
         cuts = cuts.compute_and_store_features(
             extractor=extractor,
-            storage_path=f"{feat_output_dir}/{split}_{split}_feats",
+            storage_path=f"{feat_output_dir}/{split}_feats",
             num_jobs=num_jobs if ex is None else 80,
             executor=ex,
             storage_type=LilcomChunkyWriter,
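To sanity-check the result, one could reload the stored cuts and features after the script finishes; a sketch assuming the cuts are written as a gzipped jsonl manifest next to the features (the manifest filename here is made up):

from lhotse import load_manifest_lazy

# Assumed output path; the real name is set elsewhere in the script.
cuts = load_manifest_lazy("data/fbank/cuts_audioset_balanced.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # (num_frames, 80) log-Mel filterbank array
print(cut.id, feats.shape)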