from local

dohe0342 2023-06-09 16:44:43 +09:00
parent adba1328e9
commit ca85302fa0
5 changed files with 14 additions and 28 deletions

Binary file not shown.


@@ -63,7 +63,6 @@ else
       --additional-block True \
       --prune-range 10 \
       --spk-id $2 \
-      --prefix vox
     touch ./pruned_transducer_stateless_d2v_v2/$1/.train.done
   fi
 fi
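Note: with --prefix vox deleted, the kept line --spk-id $2 \ still ends in a backslash, so the command presumably now continues onto the touch line and the .train.done marker would be passed as stray arguments instead of being created; the trailing backslash on the last remaining option likely needs to be dropped as well.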


@@ -1345,23 +1345,19 @@ def run(rank, world_size, args, wb=None):
         register_inf_check_hooks(model)

     #librispeech = LibriSpeechAsrDataModule(args)
-    ted = TedLiumAsrDataModule(args)
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    tedlium = TedLiumAsrDataModule(args)
+    train_cuts = tedlium.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 17 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 17.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 17.0

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
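For reference, a minimal sketch of the lhotse duration-filtering pattern this hunk retunes, assuming lhotse is installed; the manifest path is hypothetical. CutSet.describe() prints the utterance-duration distribution that a script like ../local/display_manifest_statistics.py would be used to inspect before choosing a cutoff such as 17.0:

    # Sketch only: hypothetical manifest path, same filter shape as the diff.
    from lhotse import CutSet

    cuts = CutSet.from_file("data/fbank/tedlium_cuts_train.jsonl.gz")

    # Inspect the duration distribution before picking thresholds.
    cuts.describe()

    def remove_short_and_long_utt(c) -> bool:
        # Keep utterances between 1 and 17 seconds, as in the new code above.
        return 1.0 <= c.duration <= 17.0

    cuts = cuts.filter(remove_short_and_long_utt)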
@@ -1372,15 +1368,12 @@ def run(rank, world_size, args, wb=None):
     else:
         sampler_state_dict = None

-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    train_dl = tedlium.train_dataloaders(train_cuts)
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)

     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
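The scaler handling kept here is the standard torch.cuda.amp checkpoint pattern; a self-contained sketch follows (the "grad_scaler" key mirrors the diff, the surrounding checkpoint layout is an assumption):

    # Sketch of saving/restoring AMP GradScaler state across runs.
    import torch

    scaler = torch.cuda.amp.GradScaler(enabled=True, init_scale=1.0)

    # Saving: keep the scaler state alongside model/optimizer state.
    checkpoints = {"grad_scaler": scaler.state_dict()}

    # Restoring: only load if the checkpoint actually carries scaler state.
    if checkpoints and "grad_scaler" in checkpoints:
        scaler.load_state_dict(checkpoints["grad_scaler"])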
@@ -1537,23 +1530,17 @@ def run_pea(rank, world_size, args, wb=None):
     scheduler_pea = Eden(optimizer_pea, 10000, 7)
     optimizer, scheduler = optimizer_pea, scheduler_pea

-    librispeech = LibriSpeechAsrDataModule(args)
-    train_cuts = librispeech.vox_cuts(option=params.spk_id)
+    tedlium = TedLiumAsrDataModule(args)
+    train_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)

     def remove_short_and_long_utt(c: Cut):
         return 1.0 <= c.duration <= 20.0

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     sampler_state_dict = None
-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
-    valid_cuts = librispeech.dev_clean_cuts(option=params.gender)
-    valid_cuts += librispeech.dev_other_cuts(option=params.gender)
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    train_dl = tedlium.train_dataloaders(train_cuts)
+    valid_dl = None

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
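user_test_cuts() is this repo's own data-module helper and its body is not part of the diff; below is a plausible lhotse-based sketch of per-speaker cut selection (the manifest path and the supervision speaker-field convention are assumptions). Setting valid_dl to None presumably skips validation during this per-speaker adaptation pass.

    # Hypothetical sketch of a per-speaker selector in the spirit of
    # TedLiumAsrDataModule.user_test_cuts(); path and speaker field assumed.
    from lhotse import CutSet

    def user_test_cuts(manifest: str, spk_id: str) -> CutSet:
        cuts = CutSet.from_file(manifest)
        # Keep only cuts whose supervision belongs to the requested speaker.
        return cuts.filter(lambda c: c.supervisions[0].speaker == spk_id)

    train_cuts = user_test_cuts("data/fbank/tedlium_cuts_user.jsonl.gz", "spk_001")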