enhance documentation

marcoyang 2024-03-26 14:56:29 +08:00
parent 7a8c9b7f53
commit f4c187286a
2 changed files with 36 additions and 11 deletions

@@ -1,7 +1,30 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the manifest and computes the fbank features for the AudioSet
dataset. The generated manifests and features are stored in data/fbank.
"""
import argparse
import csv
import glob
import logging
import os
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
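The docstring above summarizes what the script produces. As a hedged illustration (the manifest filename below is a placeholder, not taken from this commit), the generated cuts can be loaded back with Lhotse and the attached labels inspected:

from lhotse import CutSet

# Placeholder path; the actual manifest name written by this script is not shown in the diff.
cuts = CutSet.from_file("data/fbank/cuts_audioset_balanced.jsonl.gz")
for cut in cuts:
    # `audio_event` is the custom supervision field populated in main() below,
    # holding semicolon-separated label indices such as "0;451".
    print(cut.id, cut.supervisions[0].audio_event)
    break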
@@ -15,8 +38,13 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"):
def parse_csv(csv_file):
# The content of the csv file should be something like this
# ------------------------------------------------------
# filename label
# dataset/AudioSet/balanced/xxxx.wav 0;451
# dataset/AudioSet/balanced/xxxy.wav 375
# ------------------------------------------------------
mapping = {}
with open(csv_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t")
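The hunk above is truncated, so the following is only a minimal sketch of how such a tab-separated file could be turned into a filename-to-labels mapping, assuming the first row is the header; it is not the code from this commit:

import csv

def parse_csv_sketch(csv_file):
    # Hypothetical helper, not the actual parse_csv from this commit.
    mapping = {}
    with open(csv_file, "r") as fin:
        reader = csv.reader(fin, delimiter="\t")
        for i, row in enumerate(reader):
            if i == 0:
                # Skip the "filename    label" header row.
                continue
            wav_path, labels = row[0], row[1]
            mapping[wav_path] = labels  # e.g. "0;451"
    return mapping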
@@ -45,7 +73,7 @@ def get_parser():
parser.add_argument(
"--feat-output-dir",
type=str,
default="data/fbank_audioset",
default="data/fbank",
)
return parser
@@ -59,12 +87,9 @@ def main():
split = args.split
feat_output_dir = args.feat_output_dir
num_jobs = 15
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
import pdb
pdb.set_trace()
if split in ["balanced", "unbalanced"]:
csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
elif split == "eval":
@@ -100,7 +125,7 @@ def main():
supervision.audio_event = labels[cut_id]
except KeyError:
logging.info(f"No labels found for {cut_id}.")
supervision.audio_event = ""
continue
cut.supervisions = [supervision]
new_cuts.append(cut)
@@ -115,7 +140,7 @@ def main():
with get_executor() as ex:
cuts = cuts.compute_and_store_features(
extractor=extractor,
storage_path=f"{feat_output_dir}/{split}_{args.split}_feats",
storage_path=f"{feat_output_dir}/{split}_{split}_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,

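For readers unfamiliar with Lhotse, here is a stripped-down sketch of the feature-extraction step above, without the distributed executor; the manifest paths and the storage path are placeholders:

from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter

# Placeholder input manifest; the real script builds its CutSet from the parsed CSV.
cuts = CutSet.from_file("data/fbank/audioset_cuts_raw.jsonl.gz")
extractor = Fbank(FbankConfig(num_mel_bins=80))  # 80-bin fbank, matching the script
cuts = cuts.compute_and_store_features(
    extractor=extractor,
    storage_path="data/fbank/balanced_feats",  # placeholder storage path
    num_jobs=4,
    storage_type=LilcomChunkyWriter,  # chunked, lilcom-compressed feature storage
)
cuts.to_file("data/fbank/audioset_cuts_balanced.jsonl.gz")  # placeholder output manifest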
@@ -56,7 +56,7 @@ class AudioSetATDatamodule:
DataModule for k2 audio tagging (AT) experiments.
It contains all the common data pipeline modules used in ASR
It contains all the common data pipeline modules used in AT
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
@@ -64,7 +64,7 @@ class AudioSetATDatamodule:
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
This class should be derived for specific corpora used in AT tasks.
"""
def __init__(self, args: argparse.Namespace):
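Only the class docstring changes in this file, so for orientation, here is a hypothetical usage sketch; the import path and the add_arguments / audioset_train_cuts / train_dataloaders names are assumptions typical of such datamodules and are not confirmed by this diff:

import argparse

# Placeholder import path for wherever AudioSetATDatamodule is defined.
from at_datamodule import AudioSetATDatamodule

parser = argparse.ArgumentParser()
# Assumed: the datamodule registers its own CLI options on the parser.
AudioSetATDatamodule.add_arguments(parser)
args = parser.parse_args()

datamodule = AudioSetATDatamodule(args)  # grounded: __init__(self, args: argparse.Namespace)
train_cuts = datamodule.audioset_train_cuts()  # assumed accessor for the training CutSet
train_dl = datamodule.train_dataloaders(train_cuts)  # assumed: builds the bucketing-sampler DataLoader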