fix style

This commit is contained in:
marcoyang 2024-03-26 10:24:03 +08:00
parent 9c4db1b3fb
commit 4bce81bab1

View File

@ -1,72 +1,70 @@
import argparse import argparse
import csv import csv
import glob
import logging
import torch import torch
import torchaudio from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
import logging
import glob
from lhotse import load_manifest, CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.cut import MonoCut
from lhotse.audio import Recording from lhotse.audio import Recording
from lhotse.cut import MonoCut
from lhotse.supervision import SupervisionSegment from lhotse.supervision import SupervisionSegment
from argparse import ArgumentParser
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"):
    """Build a ``key -> label`` mapping from a tab-separated file.

    The first row is treated as a header and skipped. For every following
    non-empty row, the key is formed from the last two ``/``-separated
    components of column 0 (fewer if the path is shorter) and the value is
    column 1, taken verbatim.

    Args:
        csv_file: Path to the tab-delimited file to read.

    Returns:
        A dict mapping each key to the second column of its row. If a key
        occurs more than once, the last row wins.

    Raises:
        IndexError: If a non-empty row has fewer than two columns.
    """
    mapping = {}
    # newline="" lets the csv module handle line endings itself, as the
    # csv documentation requires for file objects passed to csv.reader.
    with open(csv_file, "r", newline="") as fin:
        reader = csv.reader(fin, delimiter="\t")
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing
                # empty line); indexing such a row would raise IndexError.
                continue
            key = "/".join(row[0].split("/")[-2:])
            mapping[key] = row[1]
    return mapping
def get_parser():
    """Create the command-line argument parser for this script.

    Every option carries a ``help`` string: ``ArgumentDefaultsHelpFormatter``
    only appends ``(default: ...)`` to arguments that have help text, so
    without it the defaults would never appear in ``--help``.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--dataset-dir``,
        ``--split`` and ``--feat-output-dir``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--dataset-dir",
        type=str,
        default="downloads/audioset",
        help="Directory containing the AudioSet data.",
    )

    parser.add_argument(
        "--split",
        type=str,
        default="balanced",
        choices=["balanced", "unbalanced", "eval", "eval_all"],
        help="Which split of the dataset to process.",
    )

    parser.add_argument(
        "--feat-output-dir",
        type=str,
        default="data/fbank_audioset",
        help="Output directory; the cuts manifest is written here.",
    )

    return parser
def main(): def main():
parser = get_parser() parser = get_parser()
args = parser.parse_args() args = parser.parse_args()
dataset_dir = args.dataset_dir dataset_dir = args.dataset_dir
split = args.split split = args.split
feat_output_dir = args.feat_output_dir feat_output_dir = args.feat_output_dir
num_jobs = 15 num_jobs = 15
num_mel_bins = 80 num_mel_bins = 80
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
if split in ["balanced", "unbalanced"]: if split in ["balanced", "unbalanced"]:
csv_file = "downloads/audioset/full_train_asedata_with_duration.csv" csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
elif split == "eval": elif split == "eval":
@ -79,10 +77,10 @@ def main():
labels = parse_csv(csv_file) labels = parse_csv(csv_file)
audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav") audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav")
new_cuts = [] new_cuts = []
for i, audio in enumerate(audio_files): for i, audio in enumerate(audio_files):
cut_id = "/".join(audio.split('/')[-2:]) cut_id = "/".join(audio.split("/")[-2:])
recording = Recording.from_file(audio, cut_id) recording = Recording.from_file(audio, cut_id)
cut = MonoCut( cut = MonoCut(
id=cut_id, id=cut_id,
@ -105,7 +103,7 @@ def main():
supervision.audio_event = "" supervision.audio_event = ""
cut.supervisions = [supervision] cut.supervisions = [supervision]
new_cuts.append(cut) new_cuts.append(cut)
if i % 100 == 0 and i: if i % 100 == 0 and i:
logging.info(f"Processed {i} cuts until now.") logging.info(f"Processed {i} cuts until now.")
@ -122,14 +120,15 @@ def main():
executor=ex, executor=ex,
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,
) )
manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz" manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz"
logging.info(f"Storing the manifest to {manifest_output_dir}") logging.info(f"Storing the manifest to {manifest_output_dir}")
cuts.to_jsonl(manifest_output_dir) cuts.to_jsonl(manifest_output_dir)
if __name__ == "__main__":
    # Configure root logging once for the script run, then hand over to main().
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()