black and isort formatting

This commit is contained in:
Bailey Hirota 2025-07-16 19:53:47 +09:00
parent 2f1f419149
commit 7b4abbaaac
5 changed files with 73 additions and 45 deletions

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import argparse
import inspect
import logging
@ -23,6 +22,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@ -39,6 +39,7 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
@ -46,6 +47,7 @@ class _SeedWorkers:
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class MultiDatasetAsrDataModule:
"""
DataModule for k2 ASR experiments.
@ -202,8 +204,12 @@ class MultiDatasetAsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
cuts_musan = load_manifest(
self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -218,7 +224,8 @@ class MultiDatasetAsrDataModule:
)
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap)
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
if self.args.enable_spec_aug:

View File

@ -1,12 +1,13 @@
import logging
from pathlib import Path
import os # Import os module to handle symlinks
from pathlib import Path
from lhotse import CutSet, load_manifest
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
"""
Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@ -29,7 +30,9 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> Cu
except ValueError:
# If the path does not start with old_feature_prefix, keep the cut unchanged.
# This can happen when a path is already absolute or uses a different prefix.
logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
logger.warning(
f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
)
updated_cuts.append(cut)
continue
@ -40,7 +43,9 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> Cu
else:
new_storage_path = Path("data/manifests") / dataset_name / relative_path
logger.info(f"Updating cut {cut.id}: {original_storage_path}{new_storage_path}")
logger.info(
f"Updating cut {cut.id}: {original_storage_path}{new_storage_path}"
)
cut.features.storage_path = str(new_storage_path)
updated_cuts.append(cut)
else:
@ -49,6 +54,7 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> Cu
return CutSet.from_cuts(updated_cuts)
if __name__ == "__main__":
# The root where the symlinked manifests are located in the multi_ja_en recipe
multi_recipe_manifests_root = Path("data/manifests")
@ -73,53 +79,66 @@ if __name__ == "__main__":
try:
musan_cuts = load_manifest(musan_manifest_path)
updated_musan_cuts = update_paths(
musan_cuts,
"musan",
old_feature_prefix="data/fbank"
musan_cuts, "musan", old_feature_prefix="data/fbank"
)
# Make sure we're overwriting the correct path even if it's a symlink
if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
logger.info(
f"Overwriting existing musan manifest at: {musan_manifest_path}"
)
os.unlink(musan_manifest_path)
updated_musan_cuts.to_file(musan_manifest_path)
logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
except Exception as e:
logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
logger.error(
f"Error processing musan manifest {musan_manifest_path}: {e}",
exc_info=True,
)
else:
logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
if not dataset_symlink_dir.is_dir():
logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
logger.warning(
f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
)
continue
for split in splits:
# Construct the path to the symlinked manifest file
manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
symlink_path = dataset_symlink_dir / manifest_filename # This is the path to the symlink itself
symlink_path = (
dataset_symlink_dir / manifest_filename
) # This is the path to the symlink itself
if symlink_path.is_symlink(): # Check if it's actually a symlink
# Get the actual path to the target file that the symlink points to
# Lhotse's load_manifest will follow this symlink automatically.
target_path = os.path.realpath(symlink_path)
logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
logger.info(
f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
)
elif symlink_path.is_file(): # If it's a regular file (not a symlink)
logger.info(f"Processing regular file: {symlink_path}")
target_path = symlink_path # Use its own path as target
else:
logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
logger.warning(
f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
)
continue # Skip to next iteration
try:
# Load the manifest. Lhotse will resolve the symlink internally for reading.
cuts = load_manifest(symlink_path) # Use symlink_path here, Lhotse handles resolution for loading
cuts = load_manifest(
symlink_path
) # Use symlink_path here, Lhotse handles resolution for loading
# Update the storage_path within the loaded cuts (in memory)
updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
updated_cuts = update_paths(
cuts, dataset_name, old_feature_prefix=original_feature_base_path
)
# --- CRITICAL CHANGE HERE ---
# Save the *modified* CutSet to the path of the symlink *itself*.
@ -127,7 +146,9 @@ if __name__ == "__main__":
# breaking the symlink and creating a new file in its place.
os.unlink(symlink_path)
updated_cuts.to_file(symlink_path)
logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
logger.info(
f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
)
except Exception as e:
logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)