from local

This commit is contained in:
dohe0342 2023-06-09 16:44:43 +09:00
parent adba1328e9
commit ca85302fa0
5 changed files with 14 additions and 28 deletions

Binary file not shown.

View File

@ -63,7 +63,6 @@ else
--additional-block True \
--prune-range 10 \
--spk-id $2 \
--prefix vox
touch ./pruned_transducer_stateless_d2v_v2/$1/.train.done
fi
fi

View File

@ -1345,23 +1345,19 @@ def run(rank, world_size, args, wb=None):
register_inf_check_hooks(model)
#librispeech = LibriSpeechAsrDataModule(args)
ted = TedLiumAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
tedlium = TedLiumAsrDataModule(args)
train_cuts = tedlium.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# Keep only utterances with duration between 1 second and 17 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# Caution: There is a reason to select 17.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
return 1.0 <= c.duration <= 17.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@ -1372,15 +1368,12 @@ def run(rank, world_size, args, wb=None):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
train_dl = tedlium.train_dataloaders(train_cuts)
valid_cuts = tedlium.dev_cuts()
valid_dl = tedlium.valid_dataloaders(valid_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1537,24 +1530,18 @@ def run_pea(rank, world_size, args, wb=None):
scheduler_pea = Eden(optimizer_pea, 10000, 7)
optimizer, scheduler = optimizer_pea, scheduler_pea
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.vox_cuts(option=params.spk_id)
tedlium = TedLiumAsrDataModule(args)
train_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)
def remove_short_and_long_utt(c: Cut):
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
sampler_state_dict = None
train_dl = tedlium.train_dataloaders(train_cuts)
valid_dl = None
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
valid_cuts += librispeech.dev_other_cuts(option=params.gender)
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
for epoch in range(params.start_epoch, params.num_epochs + 1):