mirror of https://github.com/k2-fsa/icefall.git
black and isort formatting
commit 7b4abbaaac (parent 2f1f419149)
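
For context, a pass like this is normally produced by running `isort` followed by `black` over the recipe directory (`isort . && black .`). The sketch below reproduces the same effect through the two tools' Python APIs; it is a minimal illustration under default settings, not the project's actual tooling, and the sample source lines are lifted from the hunks below.

# Minimal sketch of the isort + black pass (assumes both packages are
# installed): isort normalizes import order, black normalizes layout.
import black
import isort

src = (
    "import torch\n"
    "import argparse\n"
    "\n"
    "transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))\n"
)

# The sample is only formatted, never executed, so the undefined names
# inside it do not matter here.
formatted = black.format_str(isort.code(src), mode=black.Mode())
print(formatted)  # argparse sorted before torch; spaces added after commas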
@@ -213,7 +213,7 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bbpe.model"
 
     if not model_file.is_file():
         raise FileNotFoundError(f"BPE model not found at: {model_file}")
 
@@ -134,4 +134,4 @@ def main():
 
 
 if __name__ == "__main__":
     main()
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import argparse
 import inspect
 import logging
@@ -23,6 +22,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
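
Taken together, the two import hunks above are isort's default grouping at work: `import torch` leaves the standard-library block and joins the third-party block. Reconstructed from the hunks (not the complete file), the resulting import section reads:

import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import CutConcatenate  # plus the other names in the ( ... ) list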
@@ -39,6 +39,7 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
+
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -46,6 +47,7 @@ class _SeedWorkers:
     def __call__(self, worker_id: int):
         fix_random_seed(self.seed + worker_id)
 
+
 class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -202,15 +204,19 @@ class MultiDatasetAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan/musan_cuts.jsonl.gz")
-            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True))
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+            )
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
         else:
             logging.info("Disable MUSAN")
 
         # Cut concatenation should be the first transform in the list,
         # so that if we e.g. mix noise in, it will fill the gaps between
         # different utterances.
 
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -218,9 +224,10 @@ class MultiDatasetAsrDataModule:
             )
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
             ] + transforms
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
             logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
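
For context (not part of this commit), a transforms list like the one assembled above is typically handed to lhotse's K2SpeechRecognitionDataset via its cut_transforms parameter. A hedged sketch mirroring the call in the hunk; the manifest path is illustrative, since the recipe reads it from args:

from lhotse import load_manifest
from lhotse.dataset import CutMix, K2SpeechRecognitionDataset

# MUSAN noise cuts used for on-the-fly noise mixing.
cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")  # illustrative path
transforms = [CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)]
train_dataset = K2SpeechRecognitionDataset(cut_transforms=transforms)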
@@ -1,12 +1,13 @@
 import logging
+import os  # Import os module to handle symlinks
 from pathlib import Path
-import os  # Import os module to handle symlinks
 
 from lhotse import CutSet, load_manifest
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
     """
     Updates the storage_path in a CutSet's features to reflect the new dataset-specific
@@ -27,28 +28,33 @@ def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
             try:
                 relative_path = original_storage_path.relative_to(old_feature_prefix)
             except ValueError:
                 # If for some reason the path doesn't start with old_feature_prefix,
                 # keep it as is. This can happen if some paths are already absolute or different.
-                logger.warning(f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut.")
+                logger.warning(
+                    f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+                )
                 updated_cuts.append(cut)
                 continue
 
             # Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
             # Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
             if relative_path.parts[0] == dataset_name:
                 new_storage_path = Path("data/manifests") / relative_path
             else:
                 new_storage_path = Path("data/manifests") / dataset_name / relative_path
 
-            logger.info(f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}")
+            logger.info(
+                f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+            )
             cut.features.storage_path = str(new_storage_path)
             updated_cuts.append(cut)
         else:
             logger.warning(f"Skipping update for cut {cut.id}: has no features.")
             updated_cuts.append(cut)  # No features, or not a path we need to modify
 
     return CutSet.from_cuts(updated_cuts)
 
 
 if __name__ == "__main__":
     # The root where the symlinked manifests are located in the multi_ja_en recipe
     multi_recipe_manifests_root = Path("data/manifests")
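
A hypothetical end-to-end use of update_paths as defined above; the manifest filename is illustrative, while "data/fbank" mirrors the old_feature_prefix the script itself passes for MUSAN below:

from lhotse import load_manifest

# Load a per-dataset manifest, rewrite its feature paths, and save it back.
cuts = load_manifest("data/manifests/reazonspeech/reazonspeech_cuts_train.jsonl.gz")
updated = update_paths(cuts, "reazonspeech", old_feature_prefix="data/fbank")
updated.to_file("data/manifests/reazonspeech/reazonspeech_cuts_train.jsonl.gz")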
@@ -71,55 +77,68 @@ if __name__ == "__main__":
     if musan_manifest_path.exists():
         logger.info(f"Processing musan manifest: {musan_manifest_path}")
         try:
             musan_cuts = load_manifest(musan_manifest_path)
             updated_musan_cuts = update_paths(
-                musan_cuts,
-                "musan",
-                old_feature_prefix="data/fbank"
+                musan_cuts, "musan", old_feature_prefix="data/fbank"
             )
             # Make sure we're overwriting the correct path even if it's a symlink
             if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
-                logger.info(f"Overwriting existing musan manifest at: {musan_manifest_path}")
+                logger.info(
+                    f"Overwriting existing musan manifest at: {musan_manifest_path}"
+                )
                 os.unlink(musan_manifest_path)
             updated_musan_cuts.to_file(musan_manifest_path)
             logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
 
         except Exception as e:
-            logger.error(f"Error processing musan manifest {musan_manifest_path}: {e}", exc_info=True)
+            logger.error(
+                f"Error processing musan manifest {musan_manifest_path}: {e}",
+                exc_info=True,
+            )
     else:
         logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
 
     for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
         dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
         if not dataset_symlink_dir.is_dir():
-            logger.warning(f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}.")
+            logger.warning(
+                f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+            )
             continue
 
         for split in splits:
             # Construct the path to the symlinked manifest file
             manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
-            symlink_path = dataset_symlink_dir / manifest_filename  # This is the path to the symlink itself
+            symlink_path = (
+                dataset_symlink_dir / manifest_filename
+            )  # This is the path to the symlink itself
 
             if symlink_path.is_symlink():  # Check if it's actually a symlink
                 # Get the actual path to the target file that the symlink points to
                 # Lhotse's load_manifest will follow this symlink automatically.
                 target_path = os.path.realpath(symlink_path)
-                logger.info(f"Processing symlink '{symlink_path}' pointing to '{target_path}'")
+                logger.info(
+                    f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+                )
             elif symlink_path.is_file():  # If it's a regular file (not a symlink)
                 logger.info(f"Processing regular file: {symlink_path}")
                 target_path = symlink_path  # Use its own path as target
             else:
-                logger.warning(f"Manifest file not found or neither a file nor a symlink: {symlink_path}")
+                logger.warning(
+                    f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+                )
                 continue  # Skip to next iteration
 
             try:
                 # Load the manifest. Lhotse will resolve the symlink internally for reading.
-                cuts = load_manifest(symlink_path)  # Use symlink_path here, Lhotse handles resolution for loading
+                cuts = load_manifest(
+                    symlink_path
+                )  # Use symlink_path here, Lhotse handles resolution for loading
 
                 # Update the storage_path within the loaded cuts (in memory)
-                updated_cuts = update_paths(cuts, dataset_name, old_feature_prefix=original_feature_base_path)
+                updated_cuts = update_paths(
+                    cuts, dataset_name, old_feature_prefix=original_feature_base_path
+                )
 
                 # --- CRITICAL CHANGE HERE ---
                 # Save the *modified* CutSet to the path of the symlink *itself*.
@@ -127,7 +146,9 @@ if __name__ == "__main__":
                 # breaking the symlink and creating a new file in its place.
                 os.unlink(symlink_path)
                 updated_cuts.to_file(symlink_path)
-                logger.info(f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}")
+                logger.info(
+                    f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+                )
 
             except Exception as e:
                 logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
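
The unlink-then-write sequence in the hunk above is the script's "critical change": writing through a symlink would modify the original manifest in the source recipe, so the symlink is removed first and a regular file is written in its place. A minimal standalone sketch of the same pattern (the helper name is ours, not the script's):

import os
from pathlib import Path

from lhotse import CutSet


def overwrite_breaking_symlink(cuts: CutSet, path: Path) -> None:
    # Remove the symlink (or stale file) so that to_file() creates a
    # fresh regular file here instead of following the link to its target.
    if path.is_symlink() or path.exists():
        os.unlink(path)
    cuts.to_file(path)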
@@ -779,7 +779,7 @@ def main():
     # ]
 
     # for test_set, test_dl in zip(test_sets, test_dl):
-    logging.info("Start decoding test set")#: {test_set}")
+    logging.info("Start decoding test set")  #: {test_set}")
 
     results_dict = decode_dataset(
         dl=test_dl,