diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 43fa0d01b..48b347b64 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -861,15 +861,41 @@ def run(rank, world_size, args): valid_cuts = wenetspeech.valid_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15.0 seconds + # Keep only utterances with duration between 1 second and 10 seconds # - # Caution: There is a reason to select 15.0 here. Please see + # Caution: There is a reason to select 10.0 here. Please see # ../local/display_manifest_statistics.py # # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 15.0 + if c.duration < 1.0 or c.duration > 10.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = c.supervisions[0].text.replace(" ", "") + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 440b65f32..34a72be8f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -1006,15 +1006,41 @@ def run(rank, world_size, args): valid_cuts = wenetspeech.valid_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15.0 seconds + # Keep only utterances with duration between 1 second and 10 seconds # - # Caution: There is a reason to select 15.0 here. Please see + # Caution: There is a reason to select 10.0 here. Please see # ../local/display_manifest_statistics.py # # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 15.0 + if c.duration < 1.0 or c.duration > 10.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = c.supervisions[0].text.replace(" ", "") + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt)