Parametrize dev and test split sizes.

commit c8d932b0c2 (parent a6f60de9dd)
mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
@@ -10,12 +10,14 @@ def create_subset_by_hours(
     full_dataset_path,
     output_base_dir,
     target_train_hours,
+    target_dev_hours,  # New parameter
+    target_test_hours,  # New parameter
     random_seed=42,
     duration_column_name='audio_duration'
 ):
     random.seed(random_seed)

-    output_subset_dir = os.path.join(output_base_dir, f'mls_english_subset_{int(target_train_hours)}h')
+    output_subset_dir = os.path.join(output_base_dir, f'mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h')
     os.makedirs(output_subset_dir, exist_ok=True)
     output_subset_data_dir = os.path.join(output_subset_dir, 'data')
     os.makedirs(output_subset_data_dir, exist_ok=True)
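The only behavioral change in this hunk is the output directory name, which now encodes all three split sizes. A quick sketch of what the patched f-string produces (the helper name and paths below are illustrative, not part of the patch):

import os

def build_subset_dir(output_base_dir, train_h, dev_h, test_h):
    # Mirrors the patched f-string: one directory name encodes all three split sizes.
    name = f'mls_english_subset_train{int(train_h)}h_dev{int(dev_h)}h_test{int(test_h)}h'
    return os.path.join(output_base_dir, name)

print(build_subset_dir('/data/subsets', 1000, 10, 10))
# /data/subsets/mls_english_subset_train1000h_dev10h_test10h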
@@ -35,7 +37,8 @@ def create_subset_by_hours(
         sys.exit(1)

     data_files = {}
-    split_pattern = re.compile(r'^(train|dev|test)-\d{5}-of-\d{5}\.parquet$')
+    # Expanded pattern to also detect 'validation' if it's in filenames
+    split_pattern = re.compile(r'^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$')

     print(f"  Discovering splits from filenames in '{full_data_dir}'...")
     for fpath in all_parquet_files:
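The widened pattern now also claims 'validation' shards. A minimal check of what it accepts and rejects (the shard filenames are made up):

import re

split_pattern = re.compile(r'^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$')

for fname in ['train-00000-of-00012.parquet',
              'validation-00000-of-00001.parquet',
              'metadata.parquet']:
    m = split_pattern.match(fname)
    print(fname, '->', m.group(1) if m else 'skipped')
# train-00000-of-00012.parquet -> train
# validation-00000-of-00001.parquet -> validation
# metadata.parquet -> skipped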
@@ -50,7 +53,7 @@ def create_subset_by_hours(
             print(f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr)

     if not data_files:
-        print("Error: No recognized train, dev, or test parquet files found.", file=sys.stderr)
+        print("Error: No recognized train, dev, test, or validation parquet files found.", file=sys.stderr)
         sys.exit(1)

     print(f"Found splits and their parquet files: {list(data_files.keys())}")
@@ -65,93 +68,107 @@ def create_subset_by_hours(
         print("Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.", file=sys.stderr)
         sys.exit(1)

+    # --- Renaming 'validation' split to 'dev' if necessary ---
+    if 'validation' in full_dataset:
+        if 'dev' in full_dataset:
+            print("Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.", file=sys.stderr)
+        else:
+            print("Renaming 'validation' split to 'dev' for consistent keying.")
+            full_dataset['dev'] = full_dataset.pop('validation')
+    # --- End Renaming ---
+
     subset_dataset = DatasetDict()
     total_final_duration_ms = 0

     def get_duration_from_column(example):
         """Helper to safely get duration from the specified column, in milliseconds."""
         if duration_column_name in example:
             return float(example[duration_column_name]) * 1000
         else:
             print(f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.", file=sys.stderr)
             return 0

-    # --- Handle 'dev' split: Copy directly ---
-    if 'dev' in full_dataset:
-        dev_split = full_dataset['dev']
-        subset_dataset['dev'] = dev_split
-        dev_duration_ms = sum(get_duration_from_column(ex) for ex in dev_split)
-        total_final_duration_ms += dev_duration_ms
-        print(f"Copied 'dev' split directly: {len(dev_split)} samples ({dev_duration_ms / (3600*1000):.2f} hours)")
-    else:
-        print("Warning: 'dev' split not found in original dataset. Skipping copy.")
+    # --- NEW: Generalized sampling function ---
+    def sample_split_by_hours(split_name, original_split, target_hours):
+        """
+        Samples a dataset split to reach a target number of hours.
+        Returns the sampled Dataset object and its actual duration in milliseconds.
+        """
+        target_duration_ms = target_hours * 3600 * 1000
+        current_duration_ms = 0
+        indices_to_include = []

-    # --- Handle 'test' split: Copy directly ---
-    if 'test' in full_dataset:
-        test_split = full_dataset['test']
-        subset_dataset['test'] = test_split
-        test_duration_ms = sum(get_duration_from_column(ex) for ex in test_split)
-        total_final_duration_ms += test_duration_ms
-        print(f"Copied 'test' split directly: {len(test_split)} samples ({test_duration_ms / (3600*1000):.2f} hours)")
-    else:
-        print("Warning: 'test' split not found in original dataset. Skipping copy.")
+        if original_split is None or len(original_split) == 0:
+            print(f"  Warning: Original '{split_name}' split is empty or not found. Cannot sample.", file=sys.stderr)
+            return None, 0

-    # --- Handle 'train' split: Sample by target hours (stream processing) ---
-    target_train_duration_ms = target_train_hours * 3600 * 1000
-    current_train_duration_ms = 0
-    train_indices_to_include = []  # Store indices of selected samples
+        print(f"\n  Processing '{split_name}' split to reach approximately {target_hours} hours...")
+        print(f"  Total samples in original '{split_name}' split: {len(original_split)}")

-    if 'train' in full_dataset:
-        train_split = full_dataset['train']
-        print(f"\n  Processing 'train' split to reach approximately {target_train_hours} hours...")
-
-        # Get total number of samples in the train split
-        total_train_samples = len(train_split)
-        print(f"  Total samples in original train split: {total_train_samples}")
-
-        # Create a list of all indices in the train split
-        all_train_indices = list(range(total_train_samples))
-        random.shuffle(all_train_indices)  # Shuffle the indices
+        all_original_indices = list(range(len(original_split)))
+        random.shuffle(all_original_indices)  # Shuffle indices for random sampling

         num_samples_processed = 0
-        for original_idx in all_train_indices:
-            if current_train_duration_ms >= target_train_duration_ms:
-                print(f"  Target train hours reached. Stopping processing.")
-                break  # Target train hours reached, stop adding samples
+        for original_idx in all_original_indices:
+            if current_duration_ms >= target_duration_ms and target_hours > 0:
+                print(f"  Target {split_name} hours reached ({target_hours}h). Stopping processing.")
+                break

-            example = train_split[original_idx]  # Access sample by original index
+            example = original_split[original_idx]
             duration_ms = get_duration_from_column(example)

             if duration_ms > 0:
-                train_indices_to_include.append(original_idx)
-                current_train_duration_ms += duration_ms
+                indices_to_include.append(original_idx)
+                current_duration_ms += duration_ms

             num_samples_processed += 1
-            if num_samples_processed % 10000 == 0:
-                print(f"    Processed {num_samples_processed} samples. Current train duration: {current_train_duration_ms / (3600*1000):.2f} hours")
-
-        # Select the subset from the original split based on chosen indices
-        # Sorting is important here to ensure the resulting subset maintains the original order,
-        # which can be useful for debugging or consistent processing down the line.
-        selected_indices = sorted(train_indices_to_include)
-        subset_train_split = train_split.select(selected_indices)
+            if num_samples_processed % 10000 == 0:  # Print progress periodically
+                print(f"    Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours")
+
+        # If target_hours was 0, but there were samples, we should include none.
+        # Otherwise, select the chosen indices.
+        if target_hours == 0:
+            sampled_split = original_split.select([])  # Select an empty dataset
+        else:
+            sampled_split = original_split.select(sorted(indices_to_include))  # Sort to preserve order

         # Ensure the 'audio' column is correctly typed as Audio feature before saving
-        if "audio" in subset_train_split.features and not isinstance(subset_train_split.features["audio"], Audio):
-            sampling_rate = subset_train_split.features["audio"].sampling_rate if isinstance(subset_train_split.features["audio"], Audio) else 16000
-            new_features = subset_train_split.features
+        if "audio" in sampled_split.features and not isinstance(sampled_split.features["audio"], Audio):
+            sampling_rate = sampled_split.features["audio"].sampling_rate if isinstance(sampled_split.features["audio"], Audio) else 16000
+            new_features = sampled_split.features
             new_features["audio"] = Audio(sampling_rate=sampling_rate)
-            subset_train_split = subset_train_split.cast(new_features)
+            sampled_split = sampled_split.cast(new_features)

+        print(f"  Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)")
+        return sampled_split, current_duration_ms
+    # --- END NEW: Generalized sampling function ---
+
+    # --- Apply sampling for train, dev, and test splits ---
+    splits_to_process = {
+        'train': target_train_hours,
+        'dev': target_dev_hours,
+        'test': target_test_hours
+    }
+
+    for split_name, target_hours in splits_to_process.items():
+        if split_name in full_dataset:
+            original_split = full_dataset[split_name]
+            sampled_split, actual_duration_ms = sample_split_by_hours(
+                split_name,
+                original_split,
+                target_hours
+            )
+            if sampled_split is not None:
+                subset_dataset[split_name] = sampled_split
+                total_final_duration_ms += actual_duration_ms
+        else:
+            print(f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.", file=sys.stderr)

-        subset_dataset['train'] = subset_train_split
-        total_final_duration_ms += current_train_duration_ms
-        print(f"  Final 'train' split duration: {current_train_duration_ms / (3600*1000):.2f} hours ({len(subset_train_split)} samples)")
-    else:
-        print("Warning: 'train' split not found in original dataset. No training data will be created.")
-
     # --- Handle other splits if any, just copy them ---
+    # This loop now excludes 'validation' since it's handled by renaming to 'dev'
     for split_name in full_dataset.keys():
-        if split_name not in ['train', 'dev', 'test']:
+        if split_name not in ['train', 'dev', 'test', 'validation']:  # Ensure 'validation' is not re-copied if not renamed
             print(f"Copying unrecognized split '{split_name}' directly.")
             other_split = full_dataset[split_name]
             subset_dataset[split_name] = other_split
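The net effect of this hunk: one greedy, duration-capped sampler replaces the old copy-dev/copy-test logic and is applied uniformly to train, dev, and test. A self-contained toy run of the same idea, assuming an 'audio_duration' column in seconds as the script's default expects (all values below are synthetic):

import random
from datasets import Dataset

random.seed(42)
ds = Dataset.from_dict({'audio_duration': [30.0] * 1000})  # 1000 fake 30-second clips

target_ms = 0.5 * 3600 * 1000  # sample roughly 0.5 hours
picked, total_ms = [], 0
order = list(range(len(ds)))
random.shuffle(order)  # shuffle indices, then take clips until the target is met
for i in order:
    if total_ms >= target_ms:
        break
    picked.append(i)
    total_ms += ds[i]['audio_duration'] * 1000

subset = ds.select(sorted(picked))  # sorted, as in the patch, to preserve original order
print(len(subset), total_ms / (3600 * 1000))  # 60 clips, 0.5 hours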
@@ -177,8 +194,8 @@ def create_subset_by_hours(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Create a smaller subset of a downloaded Hugging Face audio dataset. "
-                    "Copies 'dev' and 'test' splits, and samples 'train' split to a target duration, "
-                    "using pre-existing duration column."
+                    "Samples train, dev, and test splits to target durations using pre-existing duration column. "
+                    "Ensures 'validation' split is renamed to 'dev'."
     )
     parser.add_argument(
         "--full-dataset-path",
@@ -193,7 +210,7 @@ if __name__ == "__main__":
         type=str,
         required=True,
         help="The base directory where the new subset dataset(s) will be saved. "
-             "A subdirectory 'mls_english_subset_Xh' will be created within it."
+             "A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it."
     )
     parser.add_argument(
         "--target-train-hours",
@@ -201,6 +218,18 @@ if __name__ == "__main__":
         required=True,
         help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours)."
     )
+    parser.add_argument(
+        "--target-dev-hours",
+        type=float,
+        default=0.0,
+        help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+    )
+    parser.add_argument(
+        "--target-test-hours",
+        type=float,
+        default=0.0,
+        help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split."
+    )
     parser.add_argument(
         "--random-seed",
         type=int,
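With default=0.0 on both new flags, omitting a flag means that split ends up empty (the sampler selects no indices when target_hours is 0). A stripped-down parser showing just that behavior, not the full CLI:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--target-dev-hours", type=float, default=0.0)
parser.add_argument("--target-test-hours", type=float, default=0.0)

# Only the dev flag is passed; the test flag falls back to its 0.0 default.
args = parser.parse_args(["--target-dev-hours", "10"])
print(args.target_dev_hours, args.target_test_hours)  # 10.0 0.0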
@@ -219,9 +248,23 @@ if __name__ == "__main__":
         args.full_dataset_path,
         args.output_base_dir,
         args.target_train_hours,
+        args.target_dev_hours,
+        args.target_test_hours,
         args.random_seed,
         args.duration_column_name
     )

-    output_subset_full_path = os.path.join(args.output_base_dir, f'mls_english_subset_{int(args.target_train_hours)}h')
-    output_subset_data_path = os.path.join(output_subset_full_path, 'data')
+    # Simplified load path message for clarity
+    output_subset_full_path_name = f'mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h'
+    output_subset_data_path = os.path.join(args.output_base_dir, output_subset_full_path_name, 'data')
+
+    print(f"\nTo use your new subset dataset, you can load it like this:")
+    print(f"from datasets import load_dataset")
+    print(f"import os, glob")
+    print(f"data_files = {{}}")
+    print(f"for split_name in ['train', 'dev', 'test']:  # Or iterate through actual splits created")
+    print(f"    split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')")
+    print(f"    files = glob.glob(split_path)")
+    print(f"    if files: data_files[split_name] = files")
+    print(f"subset = load_dataset('parquet', data_files=data_files)")
+    print(f"print(subset)")
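For completeness, a programmatic equivalent of the new CLI surface, mirroring the __main__ call above (the module name in the import is hypothetical, and all argument values are illustrative):

# Module name is hypothetical -- import from wherever this script lives in the recipe.
from create_subset_by_hours import create_subset_by_hours

create_subset_by_hours(
    full_dataset_path='/data/mls_english',
    output_base_dir='/data/subsets',
    target_train_hours=1000,
    target_dev_hours=10,   # now sampled to ~10 h instead of copied wholesale
    target_test_hours=10,  # likewise
    random_seed=42,
    duration_column_name='audio_duration',
)
# expected output directory: /data/subsets/mls_english_subset_train1000h_dev10h_test10h/data/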