black and isort formatting

Authored by Bailey Hirota on 2025-07-16 19:53:47 +09:00; committed by Kinan Martin
parent 154ef43206
commit 6012edbc17
5 changed files with 73 additions and 45 deletions
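
Every hunk below is mechanical reformatting: isort regroups imports (standard library first, then third-party packages such as torch and lhotse, with plain `import` statements ahead of `from` imports in each group), and black rewraps code to its default 88-column limit. A minimal sketch of reproducing the import reordering with isort's public `isort.code` API (the snippet is illustrative, not taken from the repo; the CLI equivalents are `isort <path>` followed by `black <path>`):

import isort

# Pre-commit import order from the first file below.
src = """\
import torch
import argparse
import inspect
import logging
from pathlib import Path
from lhotse import CutSet
"""

# Default isort settings put stdlib imports (argparse, inspect, logging,
# pathlib) before third-party ones, and straight imports before
# from-imports within each group -- so `import torch` lands in the
# third-party block just ahead of the lhotse import, matching the
# first two hunks below.
print(isort.code(src))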

View File

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import argparse
 import inspect
 import logging
@@ -23,6 +22,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -39,6 +39,7 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -46,6 +47,7 @@ class _SeedWorkers:
     def __call__(self, worker_id: int):
         fix_random_seed(self.seed + worker_id)
 
+
 class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -202,8 +204,12 @@ class MultiDatasetAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
-            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+            )
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
         else:
             logging.info("Disable MUSAN")
@@ -218,7 +224,8 @@ class MultiDatasetAsrDataModule:
             )
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
             ] + transforms
 
         if self.args.enable_spec_aug:
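
The two MUSAN/CutConcatenate hunks above show black's standard handling of over-long calls: split after the opening parenthesis, arguments indented one level, closing parenthesis dedented, tuple spacing normalized. A minimal sketch of verifying this with black's `format_str` API (assuming a recent black; the surrounding class is a stand-in for the data module, not the real code):

import black

# Recreate the original nesting depth so the line exceeds the
# 88-column default, as it did inside the data module method.
src = '''\
class M:
    def f(self):
        if True:
            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
'''

# black splits the call across lines and normalizes (10,20) to
# (10, 20), matching the hunk above.
print(black.format_str(src, mode=black.Mode()))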

View File

@@ -1,12 +1,13 @@
 import logging
+import os  # Import os module to handle symlinks
 from pathlib import Path
-import os  # Import os module to handle symlinks
 
 from lhotse import CutSet, load_manifest
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
     """
     Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@@ -27,28 +28,33 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
             try:
                 relative_path = original_storage_path.relative_to(old_feature_prefix)
             except ValueError:
                 # If for some reason the path doesn't start with old_feature_prefix,
                 # keep it as is. This can happen if some paths are already absolute or different.
-                logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
+                logger.warning(
+                    f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+                )
                 updated_cuts.append(cut)
                 continue
 
             # Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
             # Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
             if relative_path.parts[0] == dataset_name:
                 new_storage_path = Path("data/manifests") / relative_path
             else:
                 new_storage_path = Path("data/manifests") / dataset_name / relative_path
 
-            logger.info(f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}")
+            logger.info(
+                f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+            )
            cut.features.storage_path = str(new_storage_path)
            updated_cuts.append(cut)
        else:
            logger.warning(f"Skipping update for cut {cut.id}: has no features.")
            updated_cuts.append(cut)  # No features, or not a path we need to modify
 
    return CutSet.from_cuts(updated_cuts)
 
+
 if __name__ == "__main__":
     # The root where the symlinked manifests are located in the multi_ja_en recipe
     multi_recipe_manifests_root = Path("data/manifests")
@@ -71,55 +77,68 @@ if __name__ == "__main__":
     if musan_manifest_path.exists():
         logger.info(f"Processing musan manifest: {musan_manifest_path}")
         try:
             musan_cuts = load_manifest(musan_manifest_path)
             updated_musan_cuts = update_paths(
-                musan_cuts,
-                "musan",
-                old_feature_prefix="data/fbank"
+                musan_cuts, "musan", old_feature_prefix="data/fbank"
             )
             # Make sure we're overwriting the correct path even if it's a symlink
             if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
-                logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
+                logger.info(
+                    f"Overwriting existing musan manifest at: {musan_manifest_path}"
+                )
                 os.unlink(musan_manifest_path)
 
             updated_musan_cuts.to_file(musan_manifest_path)
             logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
         except Exception as e:
-            logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
+            logger.error(
+                f"Error processing musan manifest {musan_manifest_path}: {e}",
+                exc_info=True,
+            )
     else:
         logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
 
     for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
         dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
         if not dataset_symlink_dir.is_dir():
-            logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
+            logger.warning(
+                f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+            )
             continue
 
         for split in splits:
             # Construct the path to the symlinked manifest file
             manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
-            symlink_path = dataset_symlink_dir / manifest_filename  # This is the path to the symlink itself
+            symlink_path = (
+                dataset_symlink_dir / manifest_filename
+            )  # This is the path to the symlink itself
 
             if symlink_path.is_symlink():  # Check if it's actually a symlink
                 # Get the actual path to the target file that the symlink points to
                 # Lhotse's load_manifest will follow this symlink automatically.
                 target_path = os.path.realpath(symlink_path)
-                logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
-            elif symlink_path.is_file(): # If it's a regular file (not a symlink)
+                logger.info(
+                    f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+                )
+            elif symlink_path.is_file():  # If it's a regular file (not a symlink)
                 logger.info(f"Processing regular file: {symlink_path}")
                 target_path = symlink_path  # Use its own path as target
             else:
-                logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
-                continue # Skip to next iteration
+                logger.warning(
+                    f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+                )
+                continue  # Skip to next iteration
 
             try:
                 # Load the manifest. Lhotse will resolve the symlink internally for reading.
-                cuts = load_manifest(symlink_path) # Use symlink_path here, Lhotse handles resolution for loading
+                cuts = load_manifest(
+                    symlink_path
+                )  # Use symlink_path here, Lhotse handles resolution for loading
 
                 # Update the storage_path within the loaded cuts (in memory)
-                updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
+                updated_cuts = update_paths(
+                    cuts, dataset_name, old_feature_prefix=original_feature_base_path
+                )
 
                 # --- CRITICAL CHANGE HERE ---
                 # Save the *modified* CutSet to the path of the symlink *itself*.
@@ -127,7 +146,9 @@ if __name__ == "__main__":
                 # breaking the symlink and creating a new file in its place.
                 os.unlink(symlink_path)
                 updated_cuts.to_file(symlink_path)
-                logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
+                logger.info(
+                    f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+                )
             except Exception as e:
                 logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
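
Formatting aside, the behavior this script is careful about survives intact: writing through a symlink would silently modify the shared target manifest, so it unlinks the symlink first and writes a regular file in its place. A standalone sketch of that pattern (the helper name and `write` callback are illustrative, not part of the recipe):

import os
from pathlib import Path


def overwrite_possible_symlink(path: Path, write) -> None:
    """Replace `path` with a freshly written regular file.

    os.unlink on a symlink removes only the link itself, never its
    target, so the manifest it pointed to elsewhere stays untouched.
    """
    if path.is_symlink() or path.exists():
        os.unlink(path)
    write(path)  # e.g. updated_cuts.to_file, as in the script above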

View File

@@ -779,7 +779,7 @@ def main():
     # ]
     # for test_set, test_dl in zip(test_sets, test_dl):
-    logging.info("Start decoding test set")#: {test_set}")
+    logging.info("Start decoding test set")  #: {test_set}")
     results_dict = decode_dataset(
         dl=test_dl,