Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
black and isort formatting
This commit is contained in: parent 154ef43206, commit 6012edbc17
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch
 import argparse
 import inspect
 import logging
@@ -23,6 +22,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
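Note: the two hunks above are isort at work. isort groups standard-library imports (argparse, inspect, logging) ahead of third-party ones, so `import torch` moves from the top of the file down into the third-party block. A minimal sketch of driving both tools from Python, assuming black and isort are installed; the source string below is made up for illustration, not code from the repository:

# Sketch of reproducing this commit's formatting from Python.
# Assumes `pip install black isort`; `src` is a hypothetical example.
import black
import isort

src = "import torch\nimport argparse\nimport logging\n\n\ndef f(x = 1):\n    return x\n"

# isort puts stdlib imports (argparse, logging) before third-party
# ones (torch), mirroring the import moves in the hunks above.
sorted_src = isort.code(src)

# black then normalizes spacing and wraps calls longer than 88 columns.
formatted = black.format_str(sorted_src, mode=black.Mode())
print(formatted)

The same result is normally produced by running the two CLIs over the files; the Python API is just a convenient way to see the transformation on a string.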
@@ -39,6 +39,7 @@ from torch.utils.data import DataLoader

 from icefall.utils import str2bool

+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -46,6 +47,7 @@ class _SeedWorkers:
     def __call__(self, worker_id: int):
         fix_random_seed(self.seed + worker_id)

+
 class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
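Note: the added lines in the two hunks above are only the second blank line black requires before top-level definitions. For context, `_SeedWorkers` is a picklable callable meant to be passed as a DataLoader worker_init_fn, so every worker process reseeds deterministically. A minimal sketch of the pattern, assuming `fix_random_seed` comes from lhotse.utils (as in typical icefall data modules) and using a placeholder dataset:

# Minimal sketch of the _SeedWorkers pattern, assuming lhotse is installed;
# the TensorDataset is a stand-in, not the recipe's real dataset.
import torch
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader, TensorDataset


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # Each worker reseeds with a distinct, reproducible value.
        fix_random_seed(self.seed + worker_id)


if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(16).float())
    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        worker_init_fn=_SeedWorkers(seed=42),
    )
    for _batch in loader:
        pass

A stateful class instance pickles cleanly under the spawn start method, which a lambda or closure generally would not.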
@@ -202,8 +204,12 @@ class MultiDatasetAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
-            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+            )
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
         else:
             logging.info("Disable MUSAN")

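The reflow above is pure black formatting; the behavior is unchanged: MUSAN cuts are loaded and mixed into training cuts at 10-20 dB SNR with probability 0.5. A minimal standalone sketch of the same transform setup, assuming lhotse is installed; the manifest path is hypothetical, and the CutMix keyword names follow the diff (they can differ across lhotse versions):

# Minimal sketch of the MUSAN mixing transform, assuming lhotse is installed.
from pathlib import Path

from lhotse import load_manifest
from lhotse.dataset import CutMix

manifest_dir = Path("data/fbank")  # hypothetical location
cuts_musan = load_manifest(manifest_dir / "musan/musan_cuts.jsonl.gz")

transforms = [
    # Mix MUSAN noise into ~50% of training cuts at 10-20 dB SNR,
    # preserving cut IDs so supervisions stay aligned.
    CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
]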
@@ -218,7 +224,8 @@
             )
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
             ] + transforms

         if self.args.enable_spec_aug:
@@ -1,12 +1,13 @@
 import logging
-from pathlib import Path
 import os  # Import os module to handle symlinks
+from pathlib import Path
+
 from lhotse import CutSet, load_manifest

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


 def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
     """
     Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@@ -29,7 +30,9 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
         except ValueError:
             # If for some reason the path doesn't start with old_feature_prefix,
             # keep it as is. This can happen if some paths are already absolute or different.
-            logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
+            logger.warning(
+                f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+            )
             updated_cuts.append(cut)
             continue

@@ -40,7 +43,9 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
         else:
             new_storage_path = Path("data/manifests") / dataset_name / relative_path

-            logger.info(f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}")
+            logger.info(
+                f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+            )
             cut.features.storage_path = str(new_storage_path)
             updated_cuts.append(cut)
         else:
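For reference, the control flow spanning the two hunks above: update_paths computes a relative path with Path.relative_to, which raises ValueError when the feature path is not under old_feature_prefix, hence the warning branch; otherwise the path is re-rooted under data/manifests/<dataset_name>. A self-contained sketch of that rewrite logic; the helper name and example paths are illustrative only:

# Sketch of the prefix-rewrite logic inside update_paths (stdlib only).
from pathlib import Path


def rewrite_storage_path(original: str, old_prefix: str, dataset_name: str) -> str:
    try:
        # relative_to() raises ValueError when `original` is not under
        # `old_prefix`, exactly the case the warning branch handles.
        relative = Path(original).relative_to(old_prefix)
    except ValueError:
        return original  # leave absolute/foreign paths untouched
    return str(Path("data/manifests") / dataset_name / relative)


print(rewrite_storage_path("data/fbank/feats-1.lca", "data/fbank", "musan"))
# data/manifests/musan/feats-1.lca
print(rewrite_storage_path("/abs/other/feats-2.lca", "data/fbank", "musan"))
# /abs/other/feats-2.lca  (unchanged; the real script logs a warning)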
@@ -49,6 +54,7 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:

     return CutSet.from_cuts(updated_cuts)

+
 if __name__ == "__main__":
     # The root where the symlinked manifests are located in the multi_ja_en recipe
     multi_recipe_manifests_root = Path("data/manifests")
@@ -73,53 +79,66 @@ if __name__ == "__main__":
         try:
             musan_cuts = load_manifest(musan_manifest_path)
             updated_musan_cuts = update_paths(
-                musan_cuts,
-                "musan",
-                old_feature_prefix="data/fbank"
+                musan_cuts, "musan", old_feature_prefix="data/fbank"
             )
             # Make sure we're overwriting the correct path even if it's a symlink
             if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
-                logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
+                logger.info(
+                    f"Overwriting existing musan manifest at: {musan_manifest_path}"
+                )
                 os.unlink(musan_manifest_path)

             updated_musan_cuts.to_file(musan_manifest_path)
             logger.info(f"Updated musan cuts written to: {musan_manifest_path}")

         except Exception as e:
-            logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
+            logger.error(
+                f"Error processing musan manifest {musan_manifest_path}: {e}",
+                exc_info=True,
+            )
     else:
         logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")

     for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
         dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
         if not dataset_symlink_dir.is_dir():
-            logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
+            logger.warning(
+                f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+            )
             continue

         for split in splits:
             # Construct the path to the symlinked manifest file
             manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
-            symlink_path = dataset_symlink_dir / manifest_filename  # This is the path to the symlink itself
+            symlink_path = (
+                dataset_symlink_dir / manifest_filename
+            )  # This is the path to the symlink itself

             if symlink_path.is_symlink():  # Check if it's actually a symlink
                 # Get the actual path to the target file that the symlink points to
                 # Lhotse's load_manifest will follow this symlink automatically.
                 target_path = os.path.realpath(symlink_path)
-                logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
+                logger.info(
+                    f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+                )
             elif symlink_path.is_file():  # If it's a regular file (not a symlink)
                 logger.info(f"Processing regular file: {symlink_path}")
                 target_path = symlink_path  # Use its own path as target
             else:
-                logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
+                logger.warning(
+                    f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+                )
                 continue  # Skip to next iteration


             try:
                 # Load the manifest. Lhotse will resolve the symlink internally for reading.
-                cuts = load_manifest(symlink_path)  # Use symlink_path here, Lhotse handles resolution for loading
+                cuts = load_manifest(
+                    symlink_path
+                )  # Use symlink_path here, Lhotse handles resolution for loading

                 # Update the storage_path within the loaded cuts (in memory)
-                updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
+                updated_cuts = update_paths(
+                    cuts, dataset_name, old_feature_prefix=original_feature_base_path
+                )

                 # --- CRITICAL CHANGE HERE ---
                 # Save the *modified* CutSet to the path of the symlink *itself*.
@@ -127,7 +146,9 @@ if __name__ == "__main__":
                 # breaking the symlink and creating a new file in its place.
                 os.unlink(symlink_path)
                 updated_cuts.to_file(symlink_path)
-                logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
+                logger.info(
+                    f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+                )

             except Exception as e:
                 logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
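The os.unlink-before-write step in this hunk is the point of the "CRITICAL CHANGE" comment: writing through a live symlink would overwrite the target manifest in the source recipe, while unlinking first replaces the link with an independent regular file. A small self-contained demonstration (stdlib only, POSIX; file names are hypothetical):

# Demo of why the script unlinks the symlink before writing.
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "original_manifest.jsonl"
    link = Path(tmp) / "manifest.jsonl"
    target.write_text("original contents\n")
    os.symlink(target, link)

    # link.write_text(...) here would follow the link and clobber `target`.
    os.unlink(link)  # break the link first, as the script does
    link.write_text("updated contents\n")

    print(target.read_text())  # "original contents" (untouched)
    print(link.is_symlink())   # False: `link` is now a regular file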
@@ -779,7 +779,7 @@ def main():
     # ]

     # for test_set, test_dl in zip(test_sets, test_dl):
-    logging.info("Start decoding test set")#: {test_set}")
+    logging.info("Start decoding test set")  #: {test_set}")

     results_dict = decode_dataset(
         dl=test_dl,