diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 436ec53b4..48ed07d47 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -1075,6 +1075,20 @@ def run(rank, world_size, args):
             )
             return False
 
+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index b35e56abc..07df09260 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -1100,6 +1100,20 @@ def run(rank, world_size, args):
             )
             return False
 
+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
index c2d877a93..fa07298b0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
@@ -1090,6 +1090,20 @@ def run(rank, world_size, args):
             )
             return False
 
+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index 8bd00bbef..886e97f52 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -1093,6 +1093,20 @@ def run(rank, world_size, args):
             )
             return False
 
+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True
 
     # train_cuts = train_cuts.filter(remove_short_and_long_utt)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index 646f30ca1..1193510f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -1016,7 +1016,7 @@ def train_one_epoch(
 
 
 def filter_short_and_long_utterances(
-    cuts: CutSet, sp: spm.SentencePieceProcessor
+    cuts: CutSet, sp: spm.SentencePieceProcessor, zipformer_downsampling_factors: str
 ) -> CutSet:
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1053,6 +1053,20 @@ def filter_short_and_long_utterances(
             )
             return False
 
+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True
 
     cuts = cuts.filter(remove_short_and_long_utt)
@@ -1167,7 +1181,9 @@ def run(rank, world_size, args):
     else:
         train_cuts = librispeech.train_clean_100_cuts()
 
-    train_cuts = filter_short_and_long_utterances(train_cuts, sp)
+    train_cuts = filter_short_and_long_utterances(
+        train_cuts, sp, params.zipformer_downsampling_factors
+    )
 
     gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
     # XL 10k hours