From 3e22108c67a8f6909d09c2b254bb7c813555bd84 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 20 Mar 2024 17:09:26 +0800 Subject: [PATCH] update the manifest --- .../AT/local/generate_audioset_manifest.py | 135 ++++++++++++++++++ egs/audioset/AT/zipformer/at_datamodule.py | 14 +- 2 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 egs/audioset/AT/local/generate_audioset_manifest.py diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py new file mode 100644 index 000000000..321144b9a --- /dev/null +++ b/egs/audioset/AT/local/generate_audioset_manifest.py @@ -0,0 +1,135 @@ +import argparse +import csv + +import torch +import torchaudio +import logging +import glob +from lhotse import load_manifest, CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.cut import MonoCut +from lhotse.audio import Recording +from lhotse.supervision import SupervisionSegment +from argparse import ArgumentParser + +from icefall.utils import get_executor, str2bool + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"): + + mapping = {} + with open(csv_file, 'r') as fin: + reader = csv.reader(fin, delimiter="\t") + for i, row in enumerate(reader): + if i == 0: + continue + key = "/".join(row[0].split('/')[-2:]) + mapping[key] = row[1] + return mapping + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--dataset-dir", + type=str, + default="downloads/audioset" + ) + + parser.add_argument( + "--split", + type=str, + default="balanced", + choices=["balanced", "unbalanced", "eval", "eval_all"] + ) + + parser.add_argument( + "--feat-output-dir", + type=str, + default="data/fbank_audioset", + ) + + return parser + +def main(): + parser = get_parser() + args = parser.parse_args() + + dataset_dir = args.dataset_dir + split = args.split + feat_output_dir = args.feat_output_dir + + num_jobs = 15 + num_mel_bins = 80 + + import pdb; pdb.set_trace() + if split in ["balanced", "unbalanced"]: + csv_file = "downloads/audioset/full_train_asedata_with_duration.csv" + elif split == "eval": + csv_file = "downloads/audioset/eval.csv" + elif split == "eval_all": + csv_file = "downloads/audioset/eval_all.csv" + else: + raise ValueError() + + labels = parse_csv(csv_file) + + audio_files = glob.glob(f"{dataset_dir}/eval/wav_all/*.wav") + + new_cuts = [] + for i, audio in enumerate(audio_files): + cut_id = "/".join(audio.split('/')[-2:]) + recording = Recording.from_file(audio, cut_id) + cut = MonoCut( + id=cut_id, + start=0.0, + duration=recording.duration, + channel=0, + recording=recording, + ) + supervision = SupervisionSegment( + id=cut_id, + recording_id=cut.recording.id, + start=0.0, + channel=0, + duration=cut.duration, + ) + try: + supervision.audio_event = labels[cut_id] + except KeyError: + logging.info(f"No labels found for {cut_id}.") + supervision.audio_event = "" + cut.supervisions = [supervision] + new_cuts.append(cut) + + if i % 100 == 0 and i: + logging.info(f"Processed {i} cuts until now.") + + cuts = CutSet.from_cuts(new_cuts) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + logging.info(f"Computing fbank features for {split}") + with get_executor() as ex: + cuts = cuts.compute_and_store_features( + extractor=extractor, + storage_path=f"{feat_output_dir}/{split}_{args.split}_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz" + + logging.info(f"Storing the manifest to {manifest_output_dir}") + cuts.to_jsonl(manifest_output_dir) + +if __name__=="__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() \ No newline at end of file diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index 2f9cb1f78..77483a6b2 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -396,13 +396,21 @@ class AudioSetATDatamodule: @lru_cache() def audioset_train_cuts(self) -> CutSet: logging.info("About to get the audioset cuts for KD.") - cuts = load_manifest_lazy( + balanced_cuts = load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" ) - if self.args.audioset_subset == "unbalanced": - cuts += load_manifest_lazy( + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts return cuts @lru_cache()