diff --git a/egs/mls_english/ASR/local/train_bpe_model.py b/egs/mls_english/ASR/local/train_bpe_model.py
index d62c18ee2..59e79be1e 100644
--- a/egs/mls_english/ASR/local/train_bpe_model.py
+++ b/egs/mls_english/ASR/local/train_bpe_model.py
@@ -111,4 +111,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/egs/mls_english/ASR/local/utils/asr_datamodule.py b/egs/mls_english/ASR/local/utils/asr_datamodule.py
index 6a64198d9..23e50fe02 100644
--- a/egs/mls_english/ASR/local/utils/asr_datamodule.py
+++ b/egs/mls_english/ASR/local/utils/asr_datamodule.py
@@ -36,6 +36,7 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
+
 class MLSEnglishHFAsrDataModule:
     """
     DataModule for MLS English ASR experiments using HuggingFace dataset.
@@ -46,6 +47,7 @@ class MLSEnglishHFAsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
         self.dataset = None
+
         # self._validate_args()
 
     # def _validate_args(self) -> None:
@@ -59,7 +61,7 @@ class MLSEnglishHFAsrDataModule:
             title="ASR data related options",
             description="Options for data loading and processing",
         )
-        
+
         # Dataset configuration
         group.add_argument(
             "--dataset-path",
@@ -67,7 +69,7 @@ class MLSEnglishHFAsrDataModule:
             default="parler-tts/mls_eng",
             help="Path to HuggingFace MLS English dataset (name or local path)",
         )
-        
+
         # Sampling and batching
         group.add_argument(
             "--max-duration",
@@ -87,7 +89,7 @@ class MLSEnglishHFAsrDataModule:
             default=30,
             help="Number of buckets for DynamicBucketingSampler",
         )
-        
+
         # Data augmentation
         group.add_argument(
             "--enable-spec-aug",
@@ -101,7 +103,7 @@ class MLSEnglishHFAsrDataModule:
             default=80,
             help="Time warp factor for SpecAugment",
         )
-        
+
         # Dataloader configuration
         group.add_argument(
             "--num-workers",
@@ -122,7 +124,6 @@ class MLSEnglishHFAsrDataModule:
             default=True,
             help="Whether to drop last incomplete batch",
         )
-
 
         return parser
 
@@ -133,16 +134,17 @@ class MLSEnglishHFAsrDataModule:
 
         try:
             from datasets import load_dataset
+
             self.dataset = load_dataset(dataset_path)
             logging.info("Dataset loaded successfully")
         except ImportError:
-            raise ImportError(
-                "Please install datasets package: pip install datasets"
-            )
+            raise ImportError("Please install datasets package: pip install datasets")
         except Exception as e:
             raise RuntimeError(f"Failed to load dataset: {e}")
 
-    def _create_dataset(self, cuts: CutSet, is_train: bool = False) -> K2SpeechRecognitionDataset:
+    def _create_dataset(
+        self, cuts: CutSet, is_train: bool = False
+    ) -> K2SpeechRecognitionDataset:
         """Create appropriate dataset with transforms."""
         transforms = []
         input_transforms = []
@@ -160,9 +162,9 @@ class MLSEnglishHFAsrDataModule:
     def _create_spec_augment(self) -> SpecAugment:
         """Create SpecAugment transform based on config."""
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
 
@@ -174,7 +176,9 @@ class MLSEnglishHFAsrDataModule:
             frames_mask_size=100,
         )
 
-    def _create_sampler(self, cuts: CutSet, shuffle: bool) -> Union[DynamicBucketingSampler, SimpleCutSampler]:
+    def _create_sampler(
+        self, cuts: CutSet, shuffle: bool
+    ) -> Union[DynamicBucketingSampler, SimpleCutSampler]:
         """Create appropriate sampler based on config."""
         if self.args.bucketing_sampler:
             return DynamicBucketingSampler(
@@ -190,7 +194,9 @@ class MLSEnglishHFAsrDataModule:
             shuffle=shuffle,
         )
 
-    def train_dataloader(self, sampler_state_dict: Optional[Dict[str, Any]] = None) -> DataLoader:
+    def train_dataloader(
+        self, sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
         """Create train dataloader."""
         cuts = self.train_cuts()
         dataset = self._create_dataset(cuts, is_train=True)
@@ -231,20 +237,17 @@ class MLSEnglishHFAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         return CutSet.from_huggingface_dataset(
-            self.dataset["train"],
-            text_key="transcript"
+            self.dataset["train"], text_key="transcript"
         )
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         return CutSet.from_huggingface_dataset(
-            self.dataset["dev"],
-            text_key="transcript"
+            self.dataset["dev"], text_key="transcript"
         )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         return CutSet.from_huggingface_dataset(
-            self.dataset["test"],
-            text_key="transcript"
-        )
\ No newline at end of file
+            self.dataset["test"], text_key="transcript"
+        )
diff --git a/egs/mls_english/ASR/local/utils/compute_fbank_mls_english.py b/egs/mls_english/ASR/local/utils/compute_fbank_mls_english.py
index 807115e31..a7a9ca391 100644
--- a/egs/mls_english/ASR/local/utils/compute_fbank_mls_english.py
+++ b/egs/mls_english/ASR/local/utils/compute_fbank_mls_english.py
@@ -49,7 +49,7 @@ concat_params = {"gap": 1.0, "maxlen": 10.0}
 
 
 def make_cutset_blueprints(
-    mls_eng_hf_dataset_path: str = "parler-tts/mls_eng"
+    mls_eng_hf_dataset_path: str = "parler-tts/mls_eng",
 ) -> List[Tuple[str, CutSet]]:
     cut_sets = []
 
@@ -57,7 +57,7 @@ def make_cutset_blueprints(
         raise ImportError(
             "To process the MLS English HF corpus, please install optional dependency: pip install datasets"
         )
-    
+
     from datasets import load_dataset
 
     dataset = load_dataset(mls_eng_hf_dataset_path)
@@ -67,17 +67,14 @@ def make_cutset_blueprints(
     cut_sets.append(
         (
             "test",
-            CutSet.from_huggingface_dataset(dataset["test"], text_key="transcript")
+            CutSet.from_huggingface_dataset(dataset["test"], text_key="transcript"),
         )
     )
 
     # Create dev dataset
     logging.info("Creating dev cuts.")
     cut_sets.append(
-        (
-            "dev",
-            CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript")
-        )
+        ("dev", CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript"))
     )
 
     # Create train dataset
@@ -85,7 +82,7 @@ def make_cutset_blueprints(
     cut_sets.append(
         (
             "train",
-            CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript")
+            CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript"),
         )
     )
     return cut_sets
@@ -127,7 +124,7 @@ def main():
                 storage_path=(args.manifest_dir / f"feats_{part}").as_posix(),
                 storage_type=LilcomChunkyWriter,
             )
-            
+
             # cut_set.save_audios(args.audio_dir)
             # cut_set.to_file(args.manifest_dir / f"mls_eng_cuts_{part}.jsonl.gz")
 
diff --git a/egs/mls_english/ASR/local/utils/generate_transcript.py b/egs/mls_english/ASR/local/utils/generate_transcript.py
index f48093bcb..bf2ab53de 100644
--- a/egs/mls_english/ASR/local/utils/generate_transcript.py
+++ b/egs/mls_english/ASR/local/utils/generate_transcript.py
@@ -24,6 +24,7 @@ from typing import Optional
 from lhotse import CutSet
 from tqdm import tqdm
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         description="Generate transcripts for BPE training from MLS English dataset",
@@ -36,14 +37,14 @@ def get_args():
         default="parler-tts/mls_eng",
         help="Path to HuggingFace MLS English dataset (name or local path)",
     )
-    
+
     parser.add_argument(
         "--lang-dir",
         type=Path,
         default=Path("data/lang"),
         help="Directory to store output transcripts",
     )
-    
+
     parser.add_argument(
         "--split",
         type=str,
@@ -53,6 +54,7 @@ def get_args():
 
     return parser.parse_args()
 
+
 def generate_transcript_from_cuts(cuts: CutSet, output_file: Path) -> None:
     """Generate transcript text file from Lhotse CutSet."""
     with open(output_file, "w") as f:
@@ -60,6 +62,7 @@ def generate_transcript_from_cuts(cuts: CutSet, output_file: Path) -> None:
             for sup in cut.supervisions:
                 f.write(f"{sup.text}\n")
 
+
 def main():
     args = get_args()
     logging.basicConfig(
@@ -73,9 +76,7 @@ def main():
     logging.info(f"Loading {args.split} split from dataset: {args.dataset_path}")
     try:
         cuts = CutSet.from_huggingface_dataset(
-            args.dataset_path,
-            split=args.split,
-            text_key="transcript"
+            args.dataset_path, split=args.split, text_key="transcript"
         )
     except Exception as e:
         logging.error(f"Failed to load dataset: {e}")
@@ -85,5 +86,6 @@ def main():
     generate_transcript_from_cuts(cuts, output_file)
     logging.info("Transcript generation completed")
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/egs/mls_english/ASR/prepare.sh b/egs/mls_english/ASR/prepare.sh
index eb42510b9..8a46da774 100644
--- a/egs/mls_english/ASR/prepare.sh
+++ b/egs/mls_english/ASR/prepare.sh
@@ -69,4 +69,4 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   done
 fi
 
-log "MLS English data preparation completed successfully"
\ No newline at end of file
+log "MLS English data preparation completed successfully"
diff --git a/egs/mls_english/ASR/zipformer/decode.py b/egs/mls_english/ASR/zipformer/decode.py
index 088252c59..fc8de5d64 100755
--- a/egs/mls_english/ASR/zipformer/decode.py
+++ b/egs/mls_english/ASR/zipformer/decode.py
@@ -103,9 +103,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-# import sentencepiece as spm
-from tokenizer import Tokenizer
-
 import torch
 import torch.nn as nn
 from asr_datamodule import MLSEnglishHFAsrDataModule
@@ -123,6 +120,10 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
+
+# import sentencepiece as spm
+from tokenizer import Tokenizer
+
 # from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_model, get_params
 
@@ -384,6 +385,7 @@ def get_parser():
 
     return parser
 
+
 def asr_text_post_processing(inp):
     return inp
 
@@ -867,8 +869,7 @@ def main():
 
     # sp = spm.SentencePieceProcessor()
     # sp.load(params.bpe_model)
-    sp = Tokenizer.load(Path(args.lang_dir), "bpe") # force bpe model
-
+    sp = Tokenizer.load(Path(args.lang_dir), "bpe")  # force bpe model
 
     # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
diff --git a/egs/mls_english/ASR/zipformer/train.py b/egs/mls_english/ASR/zipformer/train.py
index 8c61df8f2..7c6997656 100755
--- a/egs/mls_english/ASR/zipformer/train.py
+++ b/egs/mls_english/ASR/zipformer/train.py
@@ -1115,7 +1115,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
-    sp = Tokenizer.load(Path(args.lang_dir), "bpe") # force bpe model
+    sp = Tokenizer.load(Path(args.lang_dir), "bpe")  # force bpe model
 
     # <blk> is defined in local/prepare_lang_char.py
     params.blank_id = sp.piece_to_id("<blk>")
@@ -1239,7 +1239,6 @@ def run(rank, world_size, args):
     # valid_dl = mls_english_corpus.valid_dataloader(valid_cuts)
     valid_dl = mls_english_corpus.valid_dataloader()
 
-
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,