Bailey Hirota 2025-09-02 10:36:51 +09:00
parent 7231cf44aa
commit a4c1db5a49
3 changed files with 160 additions and 78 deletions

View File

@@ -76,11 +76,19 @@ def make_cutset_blueprints(
     logging.info("Creating dev cuts.")
     try:
         cut_sets.append(
-            ("dev", CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript"))
+            (
+                "dev",
+                CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript"),
+            )
         )
     except KeyError:
         cut_sets.append(
-            ("dev", CutSet.from_huggingface_dataset(dataset["validation"], text_key="transcript"))
+            (
+                "dev",
+                CutSet.from_huggingface_dataset(
+                    dataset["validation"], text_key="transcript"
+                ),
+            )
         )
 
     # Create train dataset
@@ -121,15 +129,15 @@ def main():
         )
         return
     else:
-        mls_eng_hf_dataset_path = args.dl_dir # "/root/datasets/parler-tts--mls_eng"
+        mls_eng_hf_dataset_path = args.dl_dir  # "/root/datasets/parler-tts--mls_eng"
 
         cut_sets = make_cutset_blueprints(mls_eng_hf_dataset_path)
         for part, cut_set in cut_sets:
             logging.info(f"Processing {part}")
             cut_set = cut_set.save_audios(
                 num_jobs=num_jobs,
                 storage_path=(args.audio_dir / part).as_posix(),
-            ) # makes new cutset that loads audio from paths to actual audio files
+            )  # makes new cutset that loads audio from paths to actual audio files
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
                 num_jobs=num_jobs,
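
The dev/validation fallback introduced above maps either split name onto a single "dev" CutSet. A minimal sketch of that logic, assuming the MLS English data has already been loaded as a datasets.DatasetDict; the dev_cuts helper name is hypothetical, not part of the script:

from lhotse import CutSet


def dev_cuts(dataset):
    # MLS English exports name the held-out split either "dev" or "validation";
    # try "dev" first and fall back to "validation" on KeyError.
    try:
        split = dataset["dev"]
    except KeyError:
        split = dataset["validation"]
    return CutSet.from_huggingface_dataset(split, text_key="transcript")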

View File

@@ -1,44 +1,54 @@
 import argparse
-import os
-import sys
-from datasets import load_dataset, DatasetDict, Audio
-import random
 import glob
+import os
+import random
 import re
+import sys
+
+from datasets import Audio, DatasetDict, load_dataset
+
 
 def create_subset_by_hours(
     full_dataset_path,
     output_base_dir,
     target_train_hours,
     target_dev_hours,  # New parameter
     target_test_hours,  # New parameter
     random_seed=42,
-    duration_column_name='audio_duration'
+    duration_column_name="audio_duration",
 ):
     random.seed(random_seed)
 
-    output_subset_dir = os.path.join(output_base_dir, f'mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h')
+    output_subset_dir = os.path.join(
+        output_base_dir,
+        f"mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h",
+    )
     os.makedirs(output_subset_dir, exist_ok=True)
-    output_subset_data_dir = os.path.join(output_subset_dir, 'data')
+    output_subset_data_dir = os.path.join(output_subset_dir, "data")
     os.makedirs(output_subset_data_dir, exist_ok=True)
 
-    print(f"Attempting to load full dataset from '{full_dataset_path}' using load_dataset...")
+    print(
+        f"Attempting to load full dataset from '{full_dataset_path}' using load_dataset..."
+    )
 
-    full_data_dir = os.path.join(full_dataset_path, 'data')
+    full_data_dir = os.path.join(full_dataset_path, "data")
     if not os.path.isdir(full_data_dir):
-        print(f"Error: Expected a 'data' subdirectory at '{full_data_dir}' containing parquet files. "
-              "Please ensure 'full_dataset_path' points to the root of your MLS English download "
-              "(e.g., /path/to/mls_english_downloaded_dir) where 'data' is a direct child.", file=sys.stderr)
+        print(
+            f"Error: Expected a 'data' subdirectory at '{full_data_dir}' containing parquet files. "
+            "Please ensure 'full_dataset_path' points to the root of your MLS English download "
+            "(e.g., /path/to/mls_english_downloaded_dir) where 'data' is a direct child.",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
-    all_parquet_files = glob.glob(os.path.join(full_data_dir, '*.parquet'))
+    all_parquet_files = glob.glob(os.path.join(full_data_dir, "*.parquet"))
     if not all_parquet_files:
         print(f"Error: No parquet files found in '{full_data_dir}'.", file=sys.stderr)
         sys.exit(1)
 
     data_files = {}
     # Expanded pattern to also detect 'validation' if it's in filenames
-    split_pattern = re.compile(r'^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$')
+    split_pattern = re.compile(r"^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$")
 
     print(f" Discovering splits from filenames in '{full_data_dir}'...")
     for fpath in all_parquet_files:
@@ -50,10 +60,15 @@ def create_subset_by_hours(
                 data_files[split_name] = []
             data_files[split_name].append(fpath)
         else:
-            print(f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr)
+            print(
+                f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr
+            )
 
     if not data_files:
-        print("Error: No recognized train, dev, test, or validation parquet files found.", file=sys.stderr)
+        print(
+            "Error: No recognized train, dev, test, or validation parquet files found.",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
     print(f"Found splits and their parquet files: {list(data_files.keys())}")
@@ -61,20 +76,29 @@ def create_subset_by_hours(
     try:
         full_dataset = load_dataset("parquet", data_files=data_files)
     except Exception as e:
-        print(f"Error loading dataset from '{full_data_dir}' with load_dataset: {e}", file=sys.stderr)
+        print(
+            f"Error loading dataset from '{full_data_dir}' with load_dataset: {e}",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
     if not isinstance(full_dataset, DatasetDict):
-        print("Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.", file=sys.stderr)
+        print(
+            "Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
     # --- Renaming 'validation' split to 'dev' if necessary ---
-    if 'validation' in full_dataset:
-        if 'dev' in full_dataset:
-            print("Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.", file=sys.stderr)
+    if "validation" in full_dataset:
+        if "dev" in full_dataset:
+            print(
+                "Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.",
+                file=sys.stderr,
+            )
         else:
             print("Renaming 'validation' split to 'dev' for consistent keying.")
-            full_dataset['dev'] = full_dataset.pop('validation')
+            full_dataset["dev"] = full_dataset.pop("validation")
     # --- End Renaming ---
 
     subset_dataset = DatasetDict()
@@ -85,7 +109,10 @@ def create_subset_by_hours(
         if duration_column_name in example:
             return float(example[duration_column_name]) * 1000
         else:
-            print(f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.", file=sys.stderr)
+            print(
+                f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.",
+                file=sys.stderr,
+            )
             return 0
 
     # --- NEW: Generalized sampling function ---
@@ -99,19 +126,28 @@ def create_subset_by_hours(
         indices_to_include = []
 
         if original_split is None or len(original_split) == 0:
-            print(f" Warning: Original '{split_name}' split is empty or not found. Cannot sample.", file=sys.stderr)
+            print(
+                f" Warning: Original '{split_name}' split is empty or not found. Cannot sample.",
+                file=sys.stderr,
+            )
             return None, 0
 
-        print(f"\n Processing '{split_name}' split to reach approximately {target_hours} hours...")
-        print(f" Total samples in original '{split_name}' split: {len(original_split)}")
+        print(
+            f"\n Processing '{split_name}' split to reach approximately {target_hours} hours..."
+        )
+        print(
+            f" Total samples in original '{split_name}' split: {len(original_split)}"
+        )
 
         all_original_indices = list(range(len(original_split)))
         random.shuffle(all_original_indices)  # Shuffle indices for random sampling
 
        num_samples_processed = 0
         for original_idx in all_original_indices:
             if current_duration_ms >= target_duration_ms and target_hours > 0:
-                print(f" Target {split_name} hours reached ({target_hours}h). Stopping processing.")
+                print(
+                    f" Target {split_name} hours reached ({target_hours}h). Stopping processing."
+                )
                 break
 
             example = original_split[original_idx]
@@ -120,127 +156,156 @@ def create_subset_by_hours(
             if duration_ms > 0:
                 indices_to_include.append(original_idx)
                 current_duration_ms += duration_ms
                 num_samples_processed += 1
                 if num_samples_processed % 10000 == 0:  # Print progress periodically
-                    print(f" Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours")
+                    print(
+                        f" Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours"
+                    )
 
         # If target_hours was 0, but there were samples, we should include none.
         # Otherwise, select the chosen indices.
         if target_hours == 0:
             sampled_split = original_split.select([])  # Select an empty dataset
         else:
-            sampled_split = original_split.select(sorted(indices_to_include)) # Sort to preserve order
+            sampled_split = original_split.select(
+                sorted(indices_to_include)
+            )  # Sort to preserve order
 
         # Ensure the 'audio' column is correctly typed as Audio feature before saving
-        if "audio" in sampled_split.features and not isinstance(sampled_split.features["audio"], Audio):
-            sampling_rate = sampled_split.features["audio"].sampling_rate if isinstance(sampled_split.features["audio"], Audio) else 16000
+        if "audio" in sampled_split.features and not isinstance(
+            sampled_split.features["audio"], Audio
+        ):
+            sampling_rate = (
+                sampled_split.features["audio"].sampling_rate
+                if isinstance(sampled_split.features["audio"], Audio)
+                else 16000
+            )
             new_features = sampled_split.features
             new_features["audio"] = Audio(sampling_rate=sampling_rate)
             sampled_split = sampled_split.cast(new_features)
 
-        print(f" Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)")
+        print(
+            f" Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)"
+        )
         return sampled_split, current_duration_ms

     # --- END NEW: Generalized sampling function ---
 
     # --- Apply sampling for train, dev, and test splits ---
     splits_to_process = {
-        'train': target_train_hours,
-        'dev': target_dev_hours,
-        'test': target_test_hours
+        "train": target_train_hours,
+        "dev": target_dev_hours,
+        "test": target_test_hours,
     }
 
     for split_name, target_hours in splits_to_process.items():
         if split_name in full_dataset:
             original_split = full_dataset[split_name]
             sampled_split, actual_duration_ms = sample_split_by_hours(
-                split_name,
-                original_split,
-                target_hours
+                split_name, original_split, target_hours
             )
             if sampled_split is not None:
                 subset_dataset[split_name] = sampled_split
                 total_final_duration_ms += actual_duration_ms
         else:
-            print(f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.", file=sys.stderr)
+            print(
+                f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.",
+                file=sys.stderr,
+            )
 
     # --- Handle other splits if any, just copy them ---
     # This loop now excludes 'validation' since it's handled by renaming to 'dev'
     for split_name in full_dataset.keys():
-        if split_name not in ['train', 'dev', 'test', 'validation']: # Ensure 'validation' is not re-copied if not renamed
+        if split_name not in [
+            "train",
+            "dev",
+            "test",
+            "validation",
+        ]:  # Ensure 'validation' is not re-copied if not renamed
             print(f"Copying unrecognized split '{split_name}' directly.")
             other_split = full_dataset[split_name]
             subset_dataset[split_name] = other_split
             other_duration_ms = sum(get_duration_from_column(ex) for ex in other_split)
             total_final_duration_ms += other_duration_ms
-            print(f" Copied '{split_name}' split: {len(other_split)} samples ({other_duration_ms / (3600*1000):.2f} hours)")
+            print(
+                f" Copied '{split_name}' split: {len(other_split)} samples ({other_duration_ms / (3600*1000):.2f} hours)"
+            )
 
     final_total_hours = total_final_duration_ms / (3600 * 1000)
-    print(f"\nOverall subset dataset duration (train + dev + test + others): {final_total_hours:.2f} hours")
+    print(
+        f"\nOverall subset dataset duration (train + dev + test + others): {final_total_hours:.2f} hours"
+    )
 
-    print(f"Saving subset dataset to '{output_subset_dir}' in Parquet format, matching original 'data' structure...")
+    print(
+        f"Saving subset dataset to '{output_subset_dir}' in Parquet format, matching original 'data' structure..."
+    )
     try:
         for split_name, ds_split in subset_dataset.items():
-            ds_split.to_parquet(os.path.join(output_subset_data_dir, f'{split_name}.parquet'))
+            ds_split.to_parquet(
+                os.path.join(output_subset_data_dir, f"{split_name}.parquet")
+            )
             print(f" Saved split '{split_name}' to '{output_subset_data_dir}'")
         print(f"Successfully created and saved subset dataset to '{output_subset_dir}'")
     except Exception as e:
-        print(f"Error saving subset dataset to '{output_subset_dir}': {e}", file=sys.stderr)
+        print(
+            f"Error saving subset dataset to '{output_subset_dir}': {e}",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Create a smaller subset of a downloaded Hugging Face audio dataset. "
         "Samples train, dev, and test splits to target durations using pre-existing duration column. "
         "Ensures 'validation' split is renamed to 'dev'."
     )
     parser.add_argument(
         "--full-dataset-path",
         type=str,
         required=True,
         help="The local path to the already downloaded Hugging Face dataset. "
         "This should be the root directory containing the 'data' subdirectory "
-        "(e.g., /path/to/mls_english_download)."
+        "(e.g., /path/to/mls_english_download).",
     )
     parser.add_argument(
         "--output-base-dir",
         type=str,
         required=True,
         help="The base directory where the new subset dataset(s) will be saved. "
-        "A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it."
+        "A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it.",
     )
     parser.add_argument(
         "--target-train-hours",
         type=float,
         required=True,
-        help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours)."
+        help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours).",
     )
     parser.add_argument(
         "--target-dev-hours",
         type=float,
         default=0.0,
-        help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+        help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
     )
     parser.add_argument(
         "--target-test-hours",
         type=float,
         default=0.0,
-        help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+        help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
    )
     parser.add_argument(
         "--random-seed",
         type=int,
         default=42,
-        help="Seed for random number generation to ensure reproducibility (default: 42)."
+        help="Seed for random number generation to ensure reproducibility (default: 42).",
     )
     parser.add_argument(
         "--duration-column-name",
         type=str,
-        default='audio_duration',
-        help="The name of the column in the dataset that contains the audio duration (assumed to be in seconds). Default: 'audio_duration'."
+        default="audio_duration",
+        help="The name of the column in the dataset that contains the audio duration (assumed to be in seconds). Default: 'audio_duration'.",
     )
 
     args = parser.parse_args()
@@ -251,20 +316,26 @@ if __name__ == "__main__":
         args.target_dev_hours,
         args.target_test_hours,
         args.random_seed,
-        args.duration_column_name
+        args.duration_column_name,
     )
 
     # Simplified load path message for clarity
-    output_subset_full_path_name = f'mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h'
-    output_subset_data_path = os.path.join(args.output_base_dir, output_subset_full_path_name, 'data')
+    output_subset_full_path_name = f"mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h"
+    output_subset_data_path = os.path.join(
+        args.output_base_dir, output_subset_full_path_name, "data"
+    )
 
     print(f"\nTo use your new subset dataset, you can load it like this:")
     print(f"from datasets import load_dataset")
     print(f"import os, glob")
     print(f"data_files = {{}}")
-    print(f"for split_name in ['train', 'dev', 'test']: # Or iterate through actual splits created")
-    print(f" split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')")
+    print(
+        f"for split_name in ['train', 'dev', 'test']: # Or iterate through actual splits created"
+    )
+    print(
+        f" split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')"
+    )
     print(f" files = glob.glob(split_path)")
     print(f" if files: data_files[split_name] = files")
     print(f"subset = load_dataset('parquet', data_files=data_files)")
     print(f"print(subset)")
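
The script ends by printing load instructions for the new subset. As a runnable form of that same recipe, here is a sketch that loads a generated subset back with datasets; the subset_data_dir value is a placeholder, since the real directory name encodes whichever train/dev/test hours were requested:

import glob
import os

from datasets import load_dataset

# Placeholder path; the actual directory is named
# mls_english_subset_train<X>h_dev<Y>h_test<Z>h under --output-base-dir.
subset_data_dir = "/path/to/output_base_dir/mls_english_subset_train100h_dev10h_test10h/data"

data_files = {}
for split_name in ["train", "dev", "test"]:
    files = glob.glob(os.path.join(subset_data_dir, f"{split_name}*.parquet"))
    if files:
        data_files[split_name] = files

subset = load_dataset("parquet", data_files=data_files)
print(subset)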

View File

@@ -1,14 +1,16 @@
 import argparse
 import os
 import sys
 
 from huggingface_hub import snapshot_download
 
+
 def download_dataset(dl_dir):
     """
     Downloads the MLS English dataset from Hugging Face to `$dl_dir/mls_english`.
     """
-    repo_id = 'parler-tts/mls_eng'
-    local_dataset_dir = os.path.join(dl_dir, 'mls_english')
+    repo_id = "parler-tts/mls_eng"
+    local_dataset_dir = os.path.join(dl_dir, "mls_english")
 
     print(f"Attempting to download '{repo_id}' to '{local_dataset_dir}'...")
@@ -30,6 +32,7 @@ def download_dataset(dl_dir):
         print(f"Error downloading dataset '{repo_id}': {e}", file=sys.stderr)
         sys.exit(1)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Download MLS English dataset from Hugging Face."
@@ -42,4 +45,4 @@ if __name__ == "__main__":
     )
 
     args = parser.parse_args()
     download_dataset(args.dl_dir)
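
For reference, a minimal sketch of the download step this script wraps, assuming only that huggingface_hub is installed and the destination is writable. The dl_dir value is a placeholder for the path the script receives as its dl_dir argument, and the repo_type/local_dir keyword choices are assumptions about how snapshot_download is invoked, not copied from the script:

import os

from huggingface_hub import snapshot_download

dl_dir = "/path/to/downloads"  # placeholder; supplied on the command line in the script
local_dataset_dir = os.path.join(dl_dir, "mls_english")

# Fetch the parler-tts/mls_eng dataset snapshot into the local directory.
snapshot_download(
    repo_id="parler-tts/mls_eng",
    repo_type="dataset",
    local_dir=local_dataset_dir,
)
print(f"Downloaded 'parler-tts/mls_eng' to '{local_dataset_dir}'")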