diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index f94da9788..0b2075ca3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1076,6 +1076,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index a26f11c82..f4d071469 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1101,6 +1101,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 5585d74de..8a5bbba62 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1091,6 +1091,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 4d8a2644d..749e507a3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1099,6 +1099,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " ) + return False + return True # train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index ad14ec9dc..cab154289 100755 @@ -1022,7 +1022,7 @@ def train_one_epoch( def filter_short_and_long_utterances( - cuts: CutSet, sp: spm.SentencePieceProcessor + cuts: CutSet, sp: spm.SentencePieceProcessor, zipformer_downsampling_factors: str ) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1059,6 +1059,20 @@ def filter_short_and_long_utterances( ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1173,7 +1187,9 @@ def run(rank, world_size, args): else: train_cuts = librispeech.train_clean_100_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts, sp) + train_cuts = filter_short_and_long_utterances( + train_cuts, sp, params.zipformer_downsampling_factors + ) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours