From 4bce81bab15919e869055447f143c6e36c269f6f Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 26 Mar 2024 10:24:03 +0800 Subject: [PATCH] fix style --- .../AT/local/generate_audioset_manifest.py | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py index 321144b9a..060337a72 100644 --- a/egs/audioset/AT/local/generate_audioset_manifest.py +++ b/egs/audioset/AT/local/generate_audioset_manifest.py @@ -1,72 +1,70 @@ import argparse import csv +import glob +import logging import torch -import torchaudio -import logging -import glob -from lhotse import load_manifest, CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.cut import MonoCut +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.audio import Recording +from lhotse.cut import MonoCut from lhotse.supervision import SupervisionSegment -from argparse import ArgumentParser -from icefall.utils import get_executor, str2bool +from icefall.utils import get_executor torch.set_num_threads(1) torch.set_num_interop_threads(1) + def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"): mapping = {} - with open(csv_file, 'r') as fin: + with open(csv_file, "r") as fin: reader = csv.reader(fin, delimiter="\t") for i, row in enumerate(reader): if i == 0: continue - key = "/".join(row[0].split('/')[-2:]) + key = "/".join(row[0].split("/")[-2:]) mapping[key] = row[1] return mapping - + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - - parser.add_argument( - "--dataset-dir", - type=str, - default="downloads/audioset" - ) - + + parser.add_argument("--dataset-dir", type=str, default="downloads/audioset") + parser.add_argument( "--split", type=str, default="balanced", - choices=["balanced", "unbalanced", "eval", "eval_all"] + choices=["balanced", "unbalanced", "eval", "eval_all"], ) - + parser.add_argument( "--feat-output-dir", type=str, default="data/fbank_audioset", ) - + return parser + def main(): parser = get_parser() args = parser.parse_args() - + dataset_dir = args.dataset_dir split = args.split feat_output_dir = args.feat_output_dir - + num_jobs = 15 num_mel_bins = 80 - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() if split in ["balanced", "unbalanced"]: csv_file = "downloads/audioset/full_train_asedata_with_duration.csv" elif split == "eval": @@ -79,10 +77,10 @@ def main(): labels = parse_csv(csv_file) audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav") - + new_cuts = [] for i, audio in enumerate(audio_files): - cut_id = "/".join(audio.split('/')[-2:]) + cut_id = "/".join(audio.split("/")[-2:]) recording = Recording.from_file(audio, cut_id) cut = MonoCut( id=cut_id, @@ -105,7 +103,7 @@ def main(): supervision.audio_event = "" cut.supervisions = [supervision] new_cuts.append(cut) - + if i % 100 == 0 and i: logging.info(f"Processed {i} cuts until now.") @@ -122,14 +120,15 @@ def main(): executor=ex, storage_type=LilcomChunkyWriter, ) - + manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz" logging.info(f"Storing the manifest to {manifest_output_dir}") cuts.to_jsonl(manifest_output_dir) -if __name__=="__main__": + +if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - main() \ No newline at end of file + main()