from local

This commit is contained in:
dohe0342 2023-02-14 00:51:47 +09:00
parent 32386cafd7
commit 274e9de876
2 changed files with 18 additions and 9 deletions

View File

@ -935,12 +935,14 @@ def run(rank, world_size, args):
if params.print_diagnostics:
diagnostic = diagnostics.attach_diagnostics(model)
librispeech = LibriSpeechAsrDataModule(args)
#librispeech = LibriSpeechAsrDataModule(args)
tedlium = TedAsrDataModule(args)
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
#if params.full_libri:
# train_cuts = librispeech.train_all_shuf_cuts()
#else:
# train_cuts = librispeech.train_clean_100_cuts()
train_cuts = tedlium.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -975,13 +977,20 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
#train_dl = librispeech.train_dataloaders(
# train_cuts, sampler_state_dict=sampler_state_dict
#)
#valid_cuts = librispeech.dev_clean_cuts()
#valid_cuts += librispeech.dev_other_cuts()
#valid_dl = librispeech.valid_dataloaders(valid_cuts)
train_dl = tedlium.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
valid_cuts = tedlium.dev_cuts()
valid_dl = tedlium.valid_dataloaders(valid_cuts)
if params.print_diagnostics:
scan_pessimistic_batches_for_oom(