diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
index 31b5c4f2f..ad6bd5f40 100755
--- a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
+++ b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
@@ -213,7 +213,7 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bbpe.model"
-    
+
     if not model_file.is_file():
         raise FileNotFoundError(f"BPE model not found at: {model_file}")
 
diff --git a/egs/multi_ja_en/ASR/local/train_bbpe_model.py b/egs/multi_ja_en/ASR/local/train_bbpe_model.py
index e51193f3e..b87e6cd28 100755
--- a/egs/multi_ja_en/ASR/local/train_bbpe_model.py
+++ b/egs/multi_ja_en/ASR/local/train_bbpe_model.py
@@ -134,4 +134,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
index 5a9c3ba3a..d82375dd7 100644
--- a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
+++ b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import argparse
 import inspect
 import logging
@@ -23,6 +22,7 @@
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -39,6 +39,7 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -46,6 +47,7 @@ class _SeedWorkers:
     def __call__(self, worker_id: int):
         fix_random_seed(self.seed + worker_id)
 
+
 class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -202,15 +204,19 @@ class MultiDatasetAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
-            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+            )
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
         else:
             logging.info("Disable MUSAN")
 
         # Cut concatenation should be the first transform in the list,
         # so that if we e.g. mix noise in, it will fill the gaps between
         # different utterances.
-        
+
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -218,9 +224,10 @@ class MultiDatasetAsrDataModule:
             )
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
             ] + transforms
-        
+
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
             logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
diff --git a/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py b/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py
index 7cd5727cf..ce8d2805a 100644
--- a/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py
+++ b/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py
@@ -1,12 +1,13 @@
 import logging
+import os  # Import os module to handle symlinks
 from pathlib import Path
-import os  # Import os module to handle symlinks
 
 from lhotse import CutSet, load_manifest
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
     """
     Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@@ -27,28 +28,33 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
             try:
                 relative_path = original_storage_path.relative_to(old_feature_prefix)
             except ValueError:
-              # If for some reason the path doesn't start with old_feature_prefix,
-              # keep it as is. This can happen if some paths are already absolute or different.
-                logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
+                # If for some reason the path doesn't start with old_feature_prefix,
+                # keep it as is. This can happen if some paths are already absolute or different.
+                logger.warning(
+                    f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+                )
                 updated_cuts.append(cut)
                 continue
 
             # Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
             # Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
-            if relative_path.parts[0] == dataset_name:
-                new_storage_path = Path("data/manifests") / relative_path
+            if relative_path.parts[0] == dataset_name:
+                new_storage_path = Path("data/manifests") / relative_path
             else:
-                new_storage_path = Path("data/manifests") / dataset_name / relative_path
-            
-            logger.info(f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}")
+                new_storage_path = Path("data/manifests") / dataset_name / relative_path
+
+            logger.info(
+                f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+            )
             cut.features.storage_path = str(new_storage_path)
             updated_cuts.append(cut)
         else:
             logger.warning(f"Skipping update for cut {cut.id}: has no features.")
-            updated_cuts.append(cut) # No features, or not a path we need to modify
+            updated_cuts.append(cut)  # No features, or not a path we need to modify
 
     return CutSet.from_cuts(updated_cuts)
 
+
 if __name__ == "__main__":
     # The root where the symlinked manifests are located in the multi_ja_en recipe
     multi_recipe_manifests_root = Path("data/manifests")
@@ -71,55 +77,68 @@ if __name__ == "__main__":
     if musan_manifest_path.exists():
         logger.info(f"Processing musan manifest: {musan_manifest_path}")
         try:
-            musan_cuts = load_manifest(musan_manifest_path)
-            updated_musan_cuts = update_paths(
-                musan_cuts,
-                "musan",
-                old_feature_prefix="data/fbank"
-            )
-            # Make sure we're overwriting the correct path even if it's a symlink
-            if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
-                logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
-                os.unlink(musan_manifest_path)
-
-            updated_musan_cuts.to_file(musan_manifest_path)
-            logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
+            musan_cuts = load_manifest(musan_manifest_path)
+            updated_musan_cuts = update_paths(
+                musan_cuts, "musan", old_feature_prefix="data/fbank"
+            )
+            # Make sure we're overwriting the correct path even if it's a symlink
+            if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
+                logger.info(
+                    f"Overwriting existing musan manifest at: {musan_manifest_path}"
+                )
+                os.unlink(musan_manifest_path)
+            updated_musan_cuts.to_file(musan_manifest_path)
+            logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
         except Exception as e:
-            logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
+            logger.error(
+                f"Error processing musan manifest {musan_manifest_path}: {e}",
+                exc_info=True,
+            )
     else:
         logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
 
     for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
         dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
         if not dataset_symlink_dir.is_dir():
-            logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
+            logger.warning(
+                f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+            )
             continue
 
         for split in splits:
             # Construct the path to the symlinked manifest file
             manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
-            symlink_path = dataset_symlink_dir / manifest_filename # This is the path to the symlink itself
+            symlink_path = (
+                dataset_symlink_dir / manifest_filename
+            )  # This is the path to the symlink itself
 
-            if symlink_path.is_symlink(): # Check if it's actually a symlink
+            if symlink_path.is_symlink():  # Check if it's actually a symlink
                 # Get the actual path to the target file that the symlink points to
                 # Lhotse's load_manifest will follow this symlink automatically.
                 target_path = os.path.realpath(symlink_path)
-                logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
-            elif symlink_path.is_file(): # If it's a regular file (not a symlink)
+                logger.info(
+                    f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+                )
+            elif symlink_path.is_file():  # If it's a regular file (not a symlink)
                 logger.info(f"Processing regular file: {symlink_path}")
-                target_path = symlink_path # Use its own path as target
+                target_path = symlink_path  # Use its own path as target
             else:
-                logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
-                continue # Skip to next iteration
-
+                logger.warning(
+                    f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+                )
+                continue  # Skip to next iteration
             try:
                 # Load the manifest. Lhotse will resolve the symlink internally for reading.
-                cuts = load_manifest(symlink_path) # Use symlink_path here, Lhotse handles resolution for loading
+                cuts = load_manifest(
+                    symlink_path
+                )  # Use symlink_path here, Lhotse handles resolution for loading
 
                 # Update the storage_path within the loaded cuts (in memory)
-                updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
+                updated_cuts = update_paths(
+                    cuts, dataset_name, old_feature_prefix=original_feature_base_path
+                )
 
                 # --- CRITICAL CHANGE HERE ---
                 # Save the *modified* CutSet to the path of the symlink *itself*.
@@ -127,7 +146,9 @@ if __name__ == "__main__":
                 # breaking the symlink and creating a new file in its place.
                 os.unlink(symlink_path)
                 updated_cuts.to_file(symlink_path)
-                logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
+                logger.info(
+                    f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+                )
             except Exception as e:
                 logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
 
diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py
index 0b988b37f..b1fd44493 100755
--- a/egs/multi_ja_en/ASR/zipformer/decode.py
+++ b/egs/multi_ja_en/ASR/zipformer/decode.py
@@ -779,7 +779,7 @@ def main():
     #     ]
 
     # for test_set, test_dl in zip(test_sets, test_dl):
-    logging.info("Start decoding test set")#: {test_set}")
+    logging.info("Start decoding test set")  #: {test_set}")
     results_dict = decode_dataset(
         dl=test_dl,
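
Reviewer note on the `--- CRITICAL CHANGE HERE ---` block in `update_cutset_paths.py`: the reason for calling `os.unlink()` before `CutSet.to_file()` is that writing through a symlink follows the link and overwrites the shared manifest in the source recipe, whereas unlinking first removes only the link, so the write creates a new recipe-local file. Below is a minimal sketch of that pattern; it is an illustration, not the recipe's code, and the manifest path shown is hypothetical.

```python
import os
from pathlib import Path

from lhotse import CutSet, load_manifest


def overwrite_symlinked_manifest(manifest_path: Path, cuts: CutSet) -> None:
    """Write `cuts` to `manifest_path`, replacing a symlink with a regular file."""
    if manifest_path.is_symlink():
        # Removes only the link; the target manifest in the source
        # recipe stays untouched.
        os.unlink(manifest_path)
    # Creates a fresh regular file where the symlink used to be.
    cuts.to_file(manifest_path)


if __name__ == "__main__":
    # Hypothetical path, mirroring the data/manifests/<dataset_name>/ layout.
    path = Path("data/manifests/musan/musan_cuts.jsonl.gz")
    cuts = load_manifest(path)  # lhotse follows the symlink when reading
    overwrite_symlinked_manifest(path, cuts)
```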