diff --git a/egs/tedlium2/ASR/conformer_ctc2/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc2/.train.py.swp index 7e1995bf3..d054c650f 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc2/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc2/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc2/train.py b/egs/tedlium2/ASR/conformer_ctc2/train.py index 6c6149955..a3cf82b07 100755 --- a/egs/tedlium2/ASR/conformer_ctc2/train.py +++ b/egs/tedlium2/ASR/conformer_ctc2/train.py @@ -935,12 +935,14 @@ def run(rank, world_size, args): if params.print_diagnostics: diagnostic = diagnostics.attach_diagnostics(model) - librispeech = LibriSpeechAsrDataModule(args) + #librispeech = LibriSpeechAsrDataModule(args) + tedlium = TedAsrDataModule(args) - if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() - else: - train_cuts = librispeech.train_clean_100_cuts() + #if params.full_libri: + # train_cuts = librispeech.train_all_shuf_cuts() + #else: + # train_cuts = librispeech.train_clean_100_cuts() + train_cuts = tedlium.train_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -975,13 +977,20 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = librispeech.train_dataloaders( + #train_dl = librispeech.train_dataloaders( + # train_cuts, sampler_state_dict=sampler_state_dict + #) + + #valid_cuts = librispeech.dev_clean_cuts() + #valid_cuts += librispeech.dev_other_cuts() + #valid_dl = librispeech.valid_dataloaders(valid_cuts) + + train_dl = tedlium.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_cuts = tedlium.dev_cuts() + valid_dl = tedlium.valid_dataloaders(valid_cuts) if params.print_diagnostics: scan_pessimistic_batches_for_oom(