diff --git a/egs/librispeech/ASR/conformer_ctc2/.train.py.swp b/egs/librispeech/ASR/conformer_ctc2/.train.py.swp index b7ed26d19..3d9b44bb3 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.train.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.train.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index aab2d2acb..146d10017 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -1066,9 +1066,9 @@ def run(rank, world_size, args): 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) - + + ''' tedlium = TedLiumAsrDataModule(args) - train_cuts = tedlium.train_cuts() if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1084,6 +1084,54 @@ def run(rank, world_size, args): valid_cuts = tedlium.dev_cuts() valid_dl = tedlium.valid_dataloaders(valid_cuts) + ''' + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 30 seconds + # + # NOTE(review): the upper bound used below is 30.0, not 20.0; confirm it with + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 30.0 + + def remove_invalid_utt_ctc(c: Cut): + # Caution: We assume the subsampling factor is 4!
+ # num_tokens = len(sp.encode(c.supervisions[0].text, out_type=int)) + num_tokens = len(graph_compiler.texts_to_ids(c.supervisions[0].text)) + min_output_input_ratio = 0.0005 + max_output_input_ratio = 0.1 + return ( + min_output_input_ratio + < num_tokens / float(c.features.num_frames) + < max_output_input_ratio + ) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.filter(remove_invalid_utt_ctc) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) if ( params.start_epoch <= 1