diff --git a/egs/tedlium3/ASR/.lora.sh.swp b/egs/tedlium3/ASR/.lora.sh.swp index 02c438332..debe70d26 100644 Binary files a/egs/tedlium3/ASR/.lora.sh.swp and b/egs/tedlium3/ASR/.lora.sh.swp differ diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule.py.swp index e86a6fce2..41504430c 100644 Binary files a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule.py.swp and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule.py.swp differ diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule_libri.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule_libri.py.swp new file mode 100644 index 000000000..2b4a8325e Binary files /dev/null and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.asr_datamodule_libri.py.swp differ diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp index 4cf33358e..c64561583 100644 Binary files a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp differ diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/asr_datamodule.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/asr_datamodule.py index 387be7bca..4d3e0f93e 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/asr_datamodule.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/asr_datamodule.py @@ -31,7 +31,10 @@ from lhotse.dataset import ( SingleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + AudioSamples, +) from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -168,6 +171,18 @@ class TedLiumAsrDataModule: help="When enabled, select noise from MUSAN and mix it" "with training dataset.", ) + group.add_argument( + "--input-strategy", + type=str, + default="AudioSamples", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--spk-id", + type=int, + default=0, + ) def train_dataloaders( self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None @@ -238,13 +253,15 @@ class TedLiumAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + #input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=eval(self.args.input_strategy)(), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) else: train = K2SpeechRecognitionDataset( cut_transforms=transforms, + input_strategy=eval(self.args.input_strategy)(), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -295,12 +312,14 @@ class TedLiumAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=eval(self.args.input_strategy)(), + #input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, + input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -326,11 +345,13 @@ class TedLiumAsrDataModule: logging.debug("About to create test dataset") if self.args.on_the_fly_feats: test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + #input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) else: test = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -368,6 +389,6 @@ class TedLiumAsrDataModule: return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz") @lru_cache() - def user_test_cuts(self, user) -> CutSet: + def user_test_cuts(self, spk_id) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz") + return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{spk_id}.jsonl.gz") diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py index 88b5617fd..143a23f09 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py @@ -20,27 +20,6 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_ctc/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_ctc/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_ctc/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_ctc/exp \ - --full-libri 1 \ - --max-duration 550 - # For d2v-T training: export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" @@ -87,7 +66,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import TedLiumAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -1365,7 +1344,8 @@ def run(rank, world_size, args, wb=None): if params.inf_check: register_inf_check_hooks(model) - librispeech = LibriSpeechAsrDataModule(args) + #librispeech = LibriSpeechAsrDataModule(args) + ted = TedLiumAsrDataModule(args) train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: