enhance documentation

This commit is contained in:
marcoyang 2024-03-26 14:56:29 +08:00
parent 7a8c9b7f53
commit f4c187286a
2 changed files with 36 additions and 11 deletions

View File

@ -1,7 +1,30 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the manifest and computes the fbank features for AudioSet
dataset. The generated manifests and features are stored in data/fbank.
"""
import argparse import argparse
import csv import csv
import glob import glob
import logging import logging
import os
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
@ -15,8 +38,13 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def parse_csv(csv_file="downloads/audioset/full_train_asedata_with_duration.csv"): def parse_csv(csv_file):
# The content of the csv file shoud be something like this
# ------------------------------------------------------
# filename label
# dataset/AudioSet/balanced/xxxx.wav 0;451
# dataset/AudioSet/balanced/xxxy.wav 375
# ------------------------------------------------------
mapping = {} mapping = {}
with open(csv_file, "r") as fin: with open(csv_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t") reader = csv.reader(fin, delimiter="\t")
@ -45,7 +73,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--feat-output-dir", "--feat-output-dir",
type=str, type=str,
default="data/fbank_audioset", default="data/fbank",
) )
return parser return parser
@ -59,12 +87,9 @@ def main():
split = args.split split = args.split
feat_output_dir = args.feat_output_dir feat_output_dir = args.feat_output_dir
num_jobs = 15 num_jobs = min(15, os.cpu_count())
num_mel_bins = 80 num_mel_bins = 80
import pdb
pdb.set_trace()
if split in ["balanced", "unbalanced"]: if split in ["balanced", "unbalanced"]:
csv_file = "downloads/audioset/full_train_asedata_with_duration.csv" csv_file = "downloads/audioset/full_train_asedata_with_duration.csv"
elif split == "eval": elif split == "eval":
@ -100,7 +125,7 @@ def main():
supervision.audio_event = labels[cut_id] supervision.audio_event = labels[cut_id]
except KeyError: except KeyError:
logging.info(f"No labels found for {cut_id}.") logging.info(f"No labels found for {cut_id}.")
supervision.audio_event = "" continue
cut.supervisions = [supervision] cut.supervisions = [supervision]
new_cuts.append(cut) new_cuts.append(cut)
@ -115,7 +140,7 @@ def main():
with get_executor() as ex: with get_executor() as ex:
cuts = cuts.compute_and_store_features( cuts = cuts.compute_and_store_features(
extractor=extractor, extractor=extractor,
storage_path=f"{feat_output_dir}/{split}_{args.split}_feats", storage_path=f"{feat_output_dir}/{split}_{split}_feats",
num_jobs=num_jobs if ex is None else 80, num_jobs=num_jobs if ex is None else 80,
executor=ex, executor=ex,
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,

View File

@ -56,7 +56,7 @@ class AudioSetATDatamodule:
DataModule for k2 audio tagging (AT) experiments. DataModule for k2 audio tagging (AT) experiments.
It contains all the common data pipeline modules used in ASR It contains all the common data pipeline modules used in AT
experiments, e.g.: experiments, e.g.:
- dynamic batch size, - dynamic batch size,
- bucketing samplers, - bucketing samplers,
@ -64,7 +64,7 @@ class AudioSetATDatamodule:
- augmentation, - augmentation,
- on-the-fly feature extraction - on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks. This class should be derived for specific corpora used in AT tasks.
""" """
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):