From c8d932b0c2100c45bbe2bb88138724dd4661922d Mon Sep 17 00:00:00 2001
From: Kinan Martin
Date: Tue, 10 Jun 2025 10:11:33 +0900
Subject: [PATCH] Parametrize dev and test split sizes.

---
 .../ASR/local/utils/create_subsets_greedy.py | 179 +++++++++++-------
 1 file changed, 111 insertions(+), 68 deletions(-)

diff --git a/egs/mls_english/ASR/local/utils/create_subsets_greedy.py b/egs/mls_english/ASR/local/utils/create_subsets_greedy.py
index 84982af11..c31c96d51 100644
--- a/egs/mls_english/ASR/local/utils/create_subsets_greedy.py
+++ b/egs/mls_english/ASR/local/utils/create_subsets_greedy.py
@@ -10,12 +10,14 @@ def create_subset_by_hours(
     full_dataset_path,
     output_base_dir,
     target_train_hours,
+    target_dev_hours,  # New parameter
+    target_test_hours,  # New parameter
     random_seed=42,
     duration_column_name='audio_duration'
 ):
     random.seed(random_seed)

-    output_subset_dir = os.path.join(output_base_dir, f'mls_english_subset_{int(target_train_hours)}h')
+    output_subset_dir = os.path.join(output_base_dir, f'mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h')
     os.makedirs(output_subset_dir, exist_ok=True)
     output_subset_data_dir = os.path.join(output_subset_dir, 'data')
     os.makedirs(output_subset_data_dir, exist_ok=True)
@@ -35,7 +37,8 @@ def create_subset_by_hours(
         sys.exit(1)

     data_files = {}
-    split_pattern = re.compile(r'^(train|dev|test)-\d{5}-of-\d{5}\.parquet$')
+    # Expanded pattern to also detect 'validation' if it appears in filenames
+    split_pattern = re.compile(r'^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$')

     print(f"  Discovering splits from filenames in '{full_data_dir}'...")
     for fpath in all_parquet_files:
@@ -50,7 +53,7 @@ def create_subset_by_hours(
             print(f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr)

     if not data_files:
-        print("Error: No recognized train, dev, or test parquet files found.", file=sys.stderr)
+        print("Error: No recognized train, dev, test, or validation parquet files found.", file=sys.stderr)
         sys.exit(1)

     print(f"Found splits and their parquet files: {list(data_files.keys())}")
@@ -65,93 +68,107 @@ def create_subset_by_hours(
         print("Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.", file=sys.stderr)
         sys.exit(1)

+    # --- Rename 'validation' split to 'dev' if necessary ---
+    if 'validation' in full_dataset:
+        if 'dev' in full_dataset:
+            print("Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.", file=sys.stderr)
+        else:
+            print("Renaming 'validation' split to 'dev' for consistent keying.")
+            full_dataset['dev'] = full_dataset.pop('validation')
+    # --- End renaming ---
+
     subset_dataset = DatasetDict()
     total_final_duration_ms = 0

     def get_duration_from_column(example):
+        """Helper to safely get duration from the specified column, in milliseconds."""
         if duration_column_name in example:
             return float(example[duration_column_name]) * 1000
         else:
             print(f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.", file=sys.stderr)
             return 0

-    # --- Handle 'dev' split: Copy directly ---
-    if 'dev' in full_dataset:
-        dev_split = full_dataset['dev']
-        subset_dataset['dev'] = dev_split
-        dev_duration_ms = sum(get_duration_from_column(ex) for ex in dev_split)
-        total_final_duration_ms += dev_duration_ms
-        print(f"Copied 'dev' split directly: {len(dev_split)} samples ({dev_duration_ms / (3600*1000):.2f} hours)")
-    else:
-        print("Warning: 'dev' split not found in original dataset. Skipping copy.")
+    # --- NEW: Generalized sampling function ---
+    def sample_split_by_hours(split_name, original_split, target_hours):
+        """
+        Samples a dataset split to reach a target number of hours.
+        Returns the sampled Dataset object and its actual duration in milliseconds.
+        """
+        target_duration_ms = target_hours * 3600 * 1000
+        current_duration_ms = 0
+        indices_to_include = []

-    # --- Handle 'test' split: Copy directly ---
-    if 'test' in full_dataset:
-        test_split = full_dataset['test']
-        subset_dataset['test'] = test_split
-        test_duration_ms = sum(get_duration_from_column(ex) for ex in test_split)
-        total_final_duration_ms += test_duration_ms
-        print(f"Copied 'test' split directly: {len(test_split)} samples ({test_duration_ms / (3600*1000):.2f} hours)")
-    else:
-        print("Warning: 'test' split not found in original dataset. Skipping copy.")
+        if original_split is None or len(original_split) == 0:
+            print(f"  Warning: Original '{split_name}' split is empty or not found. Cannot sample.", file=sys.stderr)
+            return None, 0

-    # --- Handle 'train' split: Sample by target hours (stream processing) ---
-    target_train_duration_ms = target_train_hours * 3600 * 1000
-    current_train_duration_ms = 0
-    train_indices_to_include = []  # Store indices of selected samples
+        print(f"\n  Processing '{split_name}' split to reach approximately {target_hours} hours...")
+        print(f"  Total samples in original '{split_name}' split: {len(original_split)}")

-    if 'train' in full_dataset:
-        train_split = full_dataset['train']
-        print(f"\n  Processing 'train' split to reach approximately {target_train_hours} hours...")
-
-        # Get total number of samples in the train split
-        total_train_samples = len(train_split)
-        print(f"  Total samples in original train split: {total_train_samples}")
-
-        # Create a list of all indices in the train split
-        all_train_indices = list(range(total_train_samples))
-        random.shuffle(all_train_indices)  # Shuffle the indices
+        all_original_indices = list(range(len(original_split)))
+        random.shuffle(all_original_indices)  # Shuffle indices for random sampling

         num_samples_processed = 0
-        for original_idx in all_train_indices:
-            if current_train_duration_ms >= target_train_duration_ms:
-                print(f"  Target train hours reached. Stopping processing.")
-                break  # Target train hours reached, stop adding samples
+        for original_idx in all_original_indices:
+            if current_duration_ms >= target_duration_ms and target_hours > 0:
+                print(f"  Target {split_name} hours reached ({target_hours}h). Stopping processing.")
+                break

-            example = train_split[original_idx]  # Access sample by original index
+            example = original_split[original_idx]
             duration_ms = get_duration_from_column(example)

             if duration_ms > 0:
-                train_indices_to_include.append(original_idx)
-                current_train_duration_ms += duration_ms
-
+                indices_to_include.append(original_idx)
+                current_duration_ms += duration_ms
+
             num_samples_processed += 1
-            if num_samples_processed % 10000 == 0:
-                print(f"  Processed {num_samples_processed} samples. Current train duration: {current_train_duration_ms / (3600*1000):.2f} hours")
-
-
-        # Select the subset from the original split based on chosen indices
-        # Sorting is important here to ensure the resulting subset maintains the original order,
-        # which can be useful for debugging or consistent processing down the line.
-        selected_indices = sorted(train_indices_to_include)
-        subset_train_split = train_split.select(selected_indices)
+            if num_samples_processed % 10000 == 0:  # Print progress periodically
+                print(f"  Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours")
+
+        # If target_hours was 0, include no samples at all.
+        # Otherwise, select the chosen indices.
+        if target_hours == 0:
+            sampled_split = original_split.select([])  # Select an empty dataset
+        else:
+            sampled_split = original_split.select(sorted(indices_to_include))  # Sort to preserve original order

         # Ensure the 'audio' column is correctly typed as Audio feature before saving
-        if "audio" in subset_train_split.features and not isinstance(subset_train_split.features["audio"], Audio):
-            sampling_rate = subset_train_split.features["audio"].sampling_rate if isinstance(subset_train_split.features["audio"], Audio) else 16000
-            new_features = subset_train_split.features
+        if "audio" in sampled_split.features and not isinstance(sampled_split.features["audio"], Audio):
+            sampling_rate = sampled_split.features["audio"].sampling_rate if isinstance(sampled_split.features["audio"], Audio) else 16000
+            new_features = sampled_split.features
             new_features["audio"] = Audio(sampling_rate=sampling_rate)
-            subset_train_split = subset_train_split.cast(new_features)
+            sampled_split = sampled_split.cast(new_features)
+
+        print(f"  Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)")
+        return sampled_split, current_duration_ms
+    # --- END NEW: Generalized sampling function ---
+
+    # --- Apply sampling for train, dev, and test splits ---
+    splits_to_process = {
+        'train': target_train_hours,
+        'dev': target_dev_hours,
+        'test': target_test_hours
+    }
+
+    for split_name, target_hours in splits_to_process.items():
+        if split_name in full_dataset:
+            original_split = full_dataset[split_name]
+            sampled_split, actual_duration_ms = sample_split_by_hours(
+                split_name,
+                original_split,
+                target_hours
+            )
+            if sampled_split is not None:
+                subset_dataset[split_name] = sampled_split
+                total_final_duration_ms += actual_duration_ms
+        else:
+            print(f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.", file=sys.stderr)

-        subset_dataset['train'] = subset_train_split
-        total_final_duration_ms += current_train_duration_ms
-        print(f"  Final 'train' split duration: {current_train_duration_ms / (3600*1000):.2f} hours ({len(subset_train_split)} samples)")
-    else:
-        print("Warning: 'train' split not found in original dataset. No training data will be created.")

     # --- Handle other splits if any, just copy them ---
+    # This loop now excludes 'validation' since it's handled by renaming to 'dev'
     for split_name in full_dataset.keys():
-        if split_name not in ['train', 'dev', 'test']:
+        if split_name not in ['train', 'dev', 'test', 'validation']:  # Ensure 'validation' is not re-copied if not renamed
             print(f"Copying unrecognized split '{split_name}' directly.")
             other_split = full_dataset[split_name]
             subset_dataset[split_name] = other_split
@@ -177,8 +194,8 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Create a smaller subset of a downloaded Hugging Face audio dataset. "
-                    "Copies 'dev' and 'test' splits, and samples 'train' split to a target duration, "
-                    "using pre-existing duration column."
+                    "Samples train, dev, and test splits to target durations using pre-existing duration column. "
+                    "Ensures 'validation' split is renamed to 'dev'."
     )
     parser.add_argument(
         "--full-dataset-path",
@@ -193,7 +210,7 @@ if __name__ == "__main__":
         type=str,
         required=True,
         help="The base directory where the new subset dataset(s) will be saved. "
-             "A subdirectory 'mls_english_subset_Xh' will be created within it."
+             "A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it."
     )
     parser.add_argument(
         "--target-train-hours",
@@ -201,6 +218,18 @@ if __name__ == "__main__":
         required=True,
         help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours)."
     )
+    parser.add_argument(
+        "--target-dev-hours",
+        type=float,
+        default=0.0,
+        help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+    )
+    parser.add_argument(
+        "--target-test-hours",
+        type=float,
+        default=0.0,
+        help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+    )
     parser.add_argument(
         "--random-seed",
         type=int,
@@ -219,9 +248,23 @@ if __name__ == "__main__":
         args.full_dataset_path,
         args.output_base_dir,
         args.target_train_hours,
+        args.target_dev_hours,
+        args.target_test_hours,
         args.random_seed,
         args.duration_column_name
     )

-    output_subset_full_path = os.path.join(args.output_base_dir, f'mls_english_subset_{int(args.target_train_hours)}h')
-    output_subset_data_path = os.path.join(output_subset_full_path, 'data')
\ No newline at end of file
+    # Simplified load path message for clarity
+    output_subset_full_path_name = f'mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h'
+    output_subset_data_path = os.path.join(args.output_base_dir, output_subset_full_path_name, 'data')
+
+    print(f"\nTo use your new subset dataset, you can load it like this:")
+    print(f"from datasets import load_dataset")
+    print(f"import os, glob")
+    print(f"data_files = {{}}")
+    print(f"for split_name in ['train', 'dev', 'test']:  # Or iterate through actual splits created")
+    print(f"    split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')")
+    print(f"    files = glob.glob(split_path)")
+    print(f"    if files: data_files[split_name] = files")
+    print(f"subset = load_dataset('parquet', data_files=data_files)")
+    print(f"print(subset)")
\ No newline at end of file
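
---

Usage sketch (illustrative, not part of the diff): one way to drive the
parametrized script after this patch. The flags are exactly those the patch
adds; the dataset paths and hour targets below are hypothetical placeholders.

    # hypothetical driver script; adjust paths to your setup
    import subprocess

    subprocess.run(
        [
            "python", "egs/mls_english/ASR/local/utils/create_subsets_greedy.py",
            "--full-dataset-path", "/data/mls_english_full",  # assumed local dataset copy
            "--output-base-dir", "/data/subsets",             # assumed output root
            "--target-train-hours", "1000",
            "--target-dev-hours", "10",   # new flag; defaults to 0.0 (split excluded)
            "--target-test-hours", "10",  # new flag; defaults to 0.0 (split excluded)
            "--random-seed", "42",
        ],
        check=True,
    )

With these values the subset would land under
/data/subsets/mls_english_subset_train1000h_dev10h_test10h/data, matching the
directory naming scheme introduced by this patch.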