diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index 21128318b..d54157709 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -604,21 +604,18 @@ def run(rank, world_size, args): train_cuts = aishell.train_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - return 1.0 <= c.duration <= 20.0 - - num_in_total = len(train_cuts) + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 12.0 train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = aishell.train_dataloaders(train_cuts) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 0975f309a..962fffdf5 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -640,7 +640,7 @@ def train_one_epoch( def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds + # Keep only utterances with duration between 1 second and 12 seconds # # Caution: There is a reason to select 12.0 here. Please see # ../local/display_manifest_statistics.py diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index dcbc874a0..d3ffccafa 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -630,20 +630,17 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 12.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = aishell.train_dataloaders(train_cuts) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())