mirror of https://github.com/k2-fsa/icefall.git
black and isort formatting

parent 154ef43206
commit 6012edbc17
@@ -213,7 +213,7 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bbpe.model"

     if not model_file.is_file():
         raise FileNotFoundError(f"BPE model not found at: {model_file}")
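For context, the bbpe.model checked above is a SentencePiece byte-level BPE model. A minimal sketch of loading one, assuming sentencepiece is installed (the lang_dir path below is hypothetical):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_2000/bbpe.model")  # hypothetical lang_dir / "bbpe.model"
print(sp.get_piece_size())  # vocabulary size of the loaded model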
@@ -134,4 +134,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch
 import argparse
 import inspect
 import logging
@@ -23,6 +22,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
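The two hunks above are isort moving "import torch" out of the stdlib import block and into the third-party block. A minimal sketch of checking that ordering programmatically, assuming isort is installed (isort.code is its Python API):

import isort

messy = "import torch\nimport argparse\nimport logging\n"
print(isort.code(messy))
# argparse and logging (stdlib) come first, then a blank line, then torch.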
@@ -39,6 +39,7 @@ from torch.utils.data import DataLoader

 from icefall.utils import str2bool

+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -46,6 +47,7 @@ class _SeedWorkers:
     def __call__(self, worker_id: int):
         fix_random_seed(self.seed + worker_id)

+
 class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
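_SeedWorkers is passed as a DataLoader worker_init_fn so that each worker process gets a distinct random seed. A self-contained sketch of that pattern (random.seed and torch.manual_seed stand in for lhotse's fix_random_seed here):

import random

import torch
from torch.utils.data import DataLoader, Dataset


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # stand-in for lhotse's fix_random_seed(self.seed + worker_id)
        random.seed(self.seed + worker_id)
        torch.manual_seed(self.seed + worker_id)


class _ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return torch.rand(1)  # draws depend on the per-worker seed


dl = DataLoader(_ToyDataset(), batch_size=2, num_workers=2, worker_init_fn=_SeedWorkers(42))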
@@ -202,15 +204,19 @@ class MultiDatasetAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
-            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+            )
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
         else:
             logging.info("Disable MUSAN")

         # Cut concatenation should be the first transform in the list,
         # so that if we e.g. mix noise in, it will fill the gaps between
         # different utterances.

         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -218,9 +224,10 @@ class MultiDatasetAsrDataModule:
             )
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
             ] + transforms

         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
             logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
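Both hunks above only reformat how the lhotse transform list is assembled. The underlying logic, as a sketch (the manifest path and concatenation parameters are hypothetical; the call signatures mirror the diff):

from lhotse import load_manifest
from lhotse.dataset import CutConcatenate, CutMix

cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")  # hypothetical path

transforms = [CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)]
# Concatenation must come first so that mixed-in noise fills the gaps
# between concatenated utterances:
transforms = [CutConcatenate(duration_factor=1.0, gap=1.0)] + transforms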
@@ -1,12 +1,13 @@
 import logging
+import os  # Import os module to handle symlinks
 from pathlib import Path
-import os  # Import os module to handle symlinks
+
 from lhotse import CutSet, load_manifest

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


 def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
     """
     Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@@ -27,28 +28,33 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> Cu
         try:
             relative_path = original_storage_path.relative_to(old_feature_prefix)
         except ValueError:
             # If for some reason the path doesn't start with old_feature_prefix,
             # keep it as is. This can happen if some paths are already absolute or different.
-            logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
+            logger.warning(
+                f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+            )
             updated_cuts.append(cut)
             continue

         # Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
         # Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
         if relative_path.parts[0] == dataset_name:
             new_storage_path = Path("data/manifests") / relative_path
         else:
             new_storage_path = Path("data/manifests") / dataset_name / relative_path

-        logger.info(f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}")
+        logger.info(
+            f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+        )
         cut.features.storage_path = str(new_storage_path)
         updated_cuts.append(cut)
     else:
         logger.warning(f"Skipping update for cut {cut.id}: has no features.")
-        updated_cuts.append(cut) # No features, or not a path we need to modify
+        updated_cuts.append(cut)  # No features, or not a path we need to modify

     return CutSet.from_cuts(updated_cuts)


 if __name__ == "__main__":
     # The root where the symlinked manifests are located in the multi_ja_en recipe
     multi_recipe_manifests_root = Path("data/manifests")
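The try/except around relative_to works because pathlib raises ValueError whenever the path does not start with the given prefix. A stdlib-only sketch of that behaviour:

from pathlib import Path

p = Path("data/fbank/reazonspeech/feats_train/feats-12.lca")
print(p.relative_to("data/fbank"))  # reazonspeech/feats_train/feats-12.lca

try:
    Path("/elsewhere/feats.lca").relative_to("data/fbank")
except ValueError:
    print("prefix mismatch; the script logs a warning and keeps the cut as-is")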
@@ -71,55 +77,68 @@ if __name__ == "__main__":
     if musan_manifest_path.exists():
         logger.info(f"Processing musan manifest: {musan_manifest_path}")
         try:
-            musan_cuts = load_manifest(musan_manifest_path)
-            updated_musan_cuts = update_paths(
-                musan_cuts,
-                "musan",
-                old_feature_prefix="data/fbank"
-            )
-            # Make sure we're overwriting the correct path even if it's a symlink
-            if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
-                logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
-                os.unlink(musan_manifest_path)
-
-            updated_musan_cuts.to_file(musan_manifest_path)
-            logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
+            musan_cuts = load_manifest(musan_manifest_path)
+            updated_musan_cuts = update_paths(
+                musan_cuts, "musan", old_feature_prefix="data/fbank"
+            )
+            # Make sure we're overwriting the correct path even if it's a symlink
+            if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
+                logger.info(
+                    f"Overwriting existing musan manifest at: {musan_manifest_path}"
+                )
+                os.unlink(musan_manifest_path)
+            updated_musan_cuts.to_file(musan_manifest_path)
+            logger.info(f"Updated musan cuts written to: {musan_manifest_path}")

         except Exception as e:
-            logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
+            logger.error(
+                f"Error processing musan manifest {musan_manifest_path}: {e}",
+                exc_info=True,
+            )
     else:
         logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")

     for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
         dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
         if not dataset_symlink_dir.is_dir():
-            logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
+            logger.warning(
+                f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+            )
             continue

         for split in splits:
             # Construct the path to the symlinked manifest file
             manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
-            symlink_path = dataset_symlink_dir / manifest_filename # This is the path to the symlink itself
+            symlink_path = (
+                dataset_symlink_dir / manifest_filename
+            )  # This is the path to the symlink itself

-            if symlink_path.is_symlink(): # Check if it's actually a symlink
+            if symlink_path.is_symlink():  # Check if it's actually a symlink
                 # Get the actual path to the target file that the symlink points to
                 # Lhotse's load_manifest will follow this symlink automatically.
                 target_path = os.path.realpath(symlink_path)
-                logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
-            elif symlink_path.is_file(): # If it's a regular file (not a symlink)
+                logger.info(
+                    f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+                )
+            elif symlink_path.is_file():  # If it's a regular file (not a symlink)
                 logger.info(f"Processing regular file: {symlink_path}")
-                target_path = symlink_path # Use its own path as target
+                target_path = symlink_path  # Use its own path as target
             else:
-                logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
-                continue # Skip to next iteration
+                logger.warning(
+                    f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+                )
+                continue  # Skip to next iteration

             try:
                 # Load the manifest. Lhotse will resolve the symlink internally for reading.
-                cuts = load_manifest(symlink_path) # Use symlink_path here, Lhotse handles resolution for loading
+                cuts = load_manifest(
+                    symlink_path
+                )  # Use symlink_path here, Lhotse handles resolution for loading

                 # Update the storage_path within the loaded cuts (in memory)
-                updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
+                updated_cuts = update_paths(
+                    cuts, dataset_name, old_feature_prefix=original_feature_base_path
+                )

                 # --- CRITICAL CHANGE HERE ---
                 # Save the *modified* CutSet to the path of the symlink *itself*.
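The symlink handling in this hunk follows a common pattern: resolve the target for logging, then unlink the symlink itself before writing, so the rewritten manifest replaces the link rather than clobbering the shared target file. A stdlib-only sketch (the path is hypothetical):

import os
from pathlib import Path

link = Path("data/manifests/reazonspeech/reazonspeech_cuts_train.jsonl.gz")  # hypothetical

if link.is_symlink():
    target = os.path.realpath(link)  # resolve where the link points, for logging
    os.unlink(link)  # removes only the link; the original target file survives
    # ...then write the updated manifest to `link`, e.g. updated_cuts.to_file(link),
    # which creates a regular file in place of the former symlink.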
@@ -127,7 +146,9 @@ if __name__ == "__main__":
                 # breaking the symlink and creating a new file in its place.
                 os.unlink(symlink_path)
                 updated_cuts.to_file(symlink_path)
-                logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
+                logger.info(
+                    f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+                )

             except Exception as e:
                 logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
@@ -779,7 +779,7 @@ def main():
     # ]

     # for test_set, test_dl in zip(test_sets, test_dl):
-    logging.info("Start decoding test set")#: {test_set}")
+    logging.info("Start decoding test set")  #: {test_set}")

     results_dict = decode_dataset(
         dl=test_dl,
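The final hunk is black normalising an inline comment: PEP 8 asks for two spaces before a trailing "#", and black leaves the "#:" prefix itself untouched. A quick way to reproduce this via black's Python API (black.format_str), assuming black is installed:

import black

src = 'logging.info("Start decoding test set")#: {test_set}")\n'
print(black.format_str(src, mode=black.Mode()))
# expected: logging.info("Start decoding test set")  #: {test_set}")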