Fix aishell. (#416)

This commit is contained in:
Fangjun Kuang 2022-06-10 11:47:43 +08:00 committed by GitHub
parent dbda1644b5
commit bfeab319c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 23 deletions

View File

@@ -604,21 +604,18 @@ def run(rank, world_size, args):
train_cuts = aishell.train_cuts() train_cuts = aishell.train_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 12 seconds
return 1.0 <= c.duration <= 20.0 #
# Caution: There is a reason to select 12.0 here. Please see
num_in_total = len(train_cuts) # ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 12.0
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = aishell.train_dataloaders(train_cuts) train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())

View File

@@ -640,7 +640,7 @@ def train_one_epoch(
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 12 seconds
# #
# Caution: There is a reason to select 12.0 here. Please see # Caution: There is a reason to select 12.0 here. Please see
# ../local/display_manifest_statistics.py # ../local/display_manifest_statistics.py

View File

@@ -630,20 +630,17 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 12 seconds # Keep only utterances with duration between 1 second and 12 seconds
#
# Caution: There is a reason to select 12.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 12.0 return 1.0 <= c.duration <= 12.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = aishell.train_dataloaders(train_cuts) train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())