from local
This commit is contained in:
parent
6be73866c3
commit
ab16ce6424
Binary file not shown.
@ -1066,9 +1066,9 @@ def run(rank, world_size, args):
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
|
||||
'''
|
||||
tedlium = TedLiumAsrDataModule(args)
|
||||
|
||||
train_cuts = tedlium.train_cuts()
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
@ -1084,6 +1084,54 @@ def run(rank, world_size, args):
|
||||
|
||||
valid_cuts = tedlium.dev_cuts()
|
||||
valid_dl = tedlium.valid_dataloaders(valid_cuts)
|
||||
'''
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.full_libri:
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
else:
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
return 1.0 <= c.duration <= 30.0
|
||||
|
||||
def remove_invalid_utt_ctc(c: Cut):
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
# num_tokens = len(sp.encode(c.supervisions[0].text, out_type=int))
|
||||
num_tokens = len(graph_compiler.texts_to_ids(c.supervisions[0].text))
|
||||
min_output_input_ratio = 0.0005
|
||||
max_output_input_ratio = 0.1
|
||||
return (
|
||||
min_output_input_ratio
|
||||
< num_tokens / float(c.features.num_frames)
|
||||
< max_output_input_ratio
|
||||
)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
train_cuts = train_cuts.filter(remove_invalid_utt_ctc)
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
# We only load the sampler's state dict when it loads a checkpoint
|
||||
# saved in the middle of an epoch
|
||||
sampler_state_dict = checkpoints["sampler"]
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
train_dl = librispeech.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if (
|
||||
params.start_epoch <= 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user