black and isort formatting

This commit is contained in:
Bailey Hirota 2025-07-16 19:53:47 +09:00
parent 2f1f419149
commit 7b4abbaaac
5 changed files with 73 additions and 45 deletions

View File

@ -213,7 +213,7 @@ def main():
args = get_args()
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bbpe.model"
if not model_file.is_file():
raise FileNotFoundError(f"BPE model not found at: {model_file}")

View File

@ -134,4 +134,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import argparse
import inspect
import logging
@ -23,6 +22,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@ -39,6 +39,7 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
@ -46,6 +47,7 @@ class _SeedWorkers:
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class MultiDatasetAsrDataModule:
"""
DataModule for k2 ASR experiments.
@ -202,15 +204,19 @@ class MultiDatasetAsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
cuts_musan = load_manifest(
self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
@ -218,9 +224,10 @@ class MultiDatasetAsrDataModule:
)
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap)
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")

View File

@ -1,12 +1,13 @@
import logging
import os # Import os module to handle symlinks
from pathlib import Path
import os # Import os module to handle symlinks
from lhotse import CutSet, load_manifest
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
"""
Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@ -27,28 +28,33 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> Cu
try:
relative_path = original_storage_path.relative_to(old_feature_prefix)
except ValueError:
# If for some reason the path doesn't start with old_feature_prefix,
# keep it as is. This can happen if some paths are already absolute or different.
logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
# If for some reason the path doesn't start with old_feature_prefix,
# keep it as is. This can happen if some paths are already absolute or different.
logger.warning(
f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
)
updated_cuts.append(cut)
continue
# Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
# Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
if relative_path.parts[0] == dataset_name:
new_storage_path = Path("data/manifests") / relative_path
if relative_path.parts[0] == dataset_name:
new_storage_path = Path("data/manifests") / relative_path
else:
new_storage_path = Path("data/manifests") / dataset_name / relative_path
logger.info(f"Updating cut {cut.id}: {original_storage_path}{new_storage_path}")
new_storage_path = Path("data/manifests") / dataset_name / relative_path
logger.info(
f"Updating cut {cut.id}: {original_storage_path}{new_storage_path}"
)
cut.features.storage_path = str(new_storage_path)
updated_cuts.append(cut)
else:
logger.warning(f"Skipping update for cut {cut.id}: has no features.")
updated_cuts.append(cut) # No features, or not a path we need to modify
updated_cuts.append(cut) # No features, or not a path we need to modify
return CutSet.from_cuts(updated_cuts)
if __name__ == "__main__":
# The root where the symlinked manifests are located in the multi_ja_en recipe
multi_recipe_manifests_root = Path("data/manifests")
@ -71,55 +77,68 @@ if __name__ == "__main__":
if musan_manifest_path.exists():
logger.info(f"Processing musan manifest: {musan_manifest_path}")
try:
musan_cuts = load_manifest(musan_manifest_path)
updated_musan_cuts = update_paths(
musan_cuts,
"musan",
old_feature_prefix="data/fbank"
)
# Make sure we're overwriting the correct path even if it's a symlink
if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
os.unlink(musan_manifest_path)
updated_musan_cuts.to_file(musan_manifest_path)
logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
musan_cuts = load_manifest(musan_manifest_path)
updated_musan_cuts = update_paths(
musan_cuts, "musan", old_feature_prefix="data/fbank"
)
# Make sure we're overwriting the correct path even if it's a symlink
if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
logger.info(
f"Overwriting existing musan manifest at: {musan_manifest_path}"
)
os.unlink(musan_manifest_path)
updated_musan_cuts.to_file(musan_manifest_path)
logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
except Exception as e:
logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
logger.error(
f"Error processing musan manifest {musan_manifest_path}: {e}",
exc_info=True,
)
else:
logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
if not dataset_symlink_dir.is_dir():
logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
logger.warning(
f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
)
continue
for split in splits:
# Construct the path to the symlinked manifest file
manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
symlink_path = dataset_symlink_dir / manifest_filename # This is the path to the symlink itself
symlink_path = (
dataset_symlink_dir / manifest_filename
) # This is the path to the symlink itself
if symlink_path.is_symlink(): # Check if it's actually a symlink
if symlink_path.is_symlink(): # Check if it's actually a symlink
# Get the actual path to the target file that the symlink points to
# Lhotse's load_manifest will follow this symlink automatically.
target_path = os.path.realpath(symlink_path)
logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
elif symlink_path.is_file(): # If it's a regular file (not a symlink)
logger.info(
f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
)
elif symlink_path.is_file(): # If it's a regular file (not a symlink)
logger.info(f"Processing regular file: {symlink_path}")
target_path = symlink_path # Use its own path as target
target_path = symlink_path # Use its own path as target
else:
logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
continue # Skip to next iteration
logger.warning(
f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
)
continue # Skip to next iteration
try:
# Load the manifest. Lhotse will resolve the symlink internally for reading.
cuts = load_manifest(symlink_path) # Use symlink_path here, Lhotse handles resolution for loading
cuts = load_manifest(
symlink_path
) # Use symlink_path here, Lhotse handles resolution for loading
# Update the storage_path within the loaded cuts (in memory)
updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
updated_cuts = update_paths(
cuts, dataset_name, old_feature_prefix=original_feature_base_path
)
# --- CRITICAL CHANGE HERE ---
# Save the *modified* CutSet to the path of the symlink *itself*.
@ -127,7 +146,9 @@ if __name__ == "__main__":
# breaking the symlink and creating a new file in its place.
os.unlink(symlink_path)
updated_cuts.to_file(symlink_path)
logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
logger.info(
f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
)
except Exception as e:
logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)

View File

@ -779,7 +779,7 @@ def main():
# ]
# for test_set, test_dl in zip(test_sets, test_dl):
logging.info("Start decoding test set")#: {test_set}")
logging.info("Start decoding test set") #: {test_set}")
results_dict = decode_dataset(
dl=test_dl,