diff --git a/egs/tedlium3/ASR/.lora.sh.swp b/egs/tedlium3/ASR/.lora.sh.swp
index debe70d26..688dbd961 100644
Binary files a/egs/tedlium3/ASR/.lora.sh.swp and b/egs/tedlium3/ASR/.lora.sh.swp differ
diff --git a/egs/tedlium3/ASR/lora.sh b/egs/tedlium3/ASR/lora.sh
index f27b74630..216ad1905 100755
--- a/egs/tedlium3/ASR/lora.sh
+++ b/egs/tedlium3/ASR/lora.sh
@@ -63,7 +63,6 @@ else
       --additional-block True \
       --prune-range 10 \
-      --spk-id $2 \
-      --prefix vox
+      --spk-id $2
     touch ./pruned_transducer_stateless_d2v_v2/$1/.train.done
   fi
 fi
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/.train.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless/.train.py.swp
index ed5985376..59b44523e 100644
Binary files a/egs/tedlium3/ASR/pruned_transducer_stateless/.train.py.swp and b/egs/tedlium3/ASR/pruned_transducer_stateless/.train.py.swp differ
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp
index f3d14bf04..1fa2e8fbc 100644
Binary files a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp and b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/.train_tta.py.swp differ
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
index 143a23f09..0468e016e 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/train_tta.py
@@ -1345,23 +1345,19 @@ def run(rank, world_size, args, wb=None):
         register_inf_check_hooks(model)
 
     #librispeech = LibriSpeechAsrDataModule(args)
-    ted = TedLiumAsrDataModule(args)
-
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    tedlium = TedLiumAsrDataModule(args)
+    train_cuts = tedlium.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 17 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 17.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 17.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
@@ -1372,15 +1368,12 @@ def run(rank, world_size, args, wb=None):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
+    train_dl = tedlium.train_dataloaders(train_cuts)
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1537,24 +1530,18 @@ def run_pea(rank, world_size, args, wb=None):
     scheduler_pea = Eden(optimizer_pea, 10000, 7)
     optimizer, scheduler = optimizer_pea, scheduler_pea
 
-    librispeech = LibriSpeechAsrDataModule(args)
-    train_cuts = librispeech.vox_cuts(option=params.spk_id)
+    tedlium = TedLiumAsrDataModule(args)
+    train_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)
 
     def remove_short_and_long_utt(c: Cut):
         return 1.0 <= c.duration <= 20.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    sampler_state_dict = None
+
+    train_dl = tedlium.train_dataloaders(train_cuts)
+    valid_dl = None
 
-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
-
-    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
-    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):