From 723320e0159d65856d3c86ae4f75b9d44034fc3d Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 30 Dec 2022 16:06:28 +0800 Subject: [PATCH] add length filter condition --- .../ASR/pruned_transducer_stateless7/train.py | 14 +++++++++++++ .../pruned_transducer_stateless7_ctc/train.py | 14 +++++++++++++ .../train.py | 16 ++++++++++++++- .../train.py | 14 +++++++++++++ .../ASR/pruned_transducer_stateless8/train.py | 20 +++++++++++++++++-- 5 files changed, 75 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 31a3a0505..06ced0534 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1064,6 +1064,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 5a05e1836..ecf2ac698 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1112,6 +1112,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 522ecc974..2f37229fa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -55,9 +55,9 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder +from frame_reducer import FrameReducer from joiner import Joiner from lconv import LConv -from frame_reducer import FrameReducer from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed @@ -1103,6 +1103,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2bdc882a5..901d5fd37 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1089,6 +1089,20 @@ def run(rank, world_size, args): ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, params.zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index abe249c7b..60c19c5de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1017,7 +1017,7 @@ def train_one_epoch( def filter_short_and_long_utterances( - cuts: CutSet, sp: spm.SentencePieceProcessor + cuts: CutSet, sp: spm.SentencePieceProcessor, zipformer_downsampling_factors: str ) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1054,6 +1054,20 @@ def filter_short_and_long_utterances( ) return False + # Zipformer has DownsampledZipformerEncoders with different downsampling factors + # after encoder_embed that does T -> (T - 7) // 2 + ds = tuple(map(int, zipformer_downsampling_factors.split(","))) + max_ds = max(ds) + T = (c.num_frames - 7) // 2 + if T < max_ds: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before encoder_embed): {c.num_frames}. " + f"Number of frames (after encoder_embed): {T}. " + f"Max downsampling factor in Zipformer: {max_ds}. " + ) + return False + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1159,7 +1173,9 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts, sp) + train_cuts = filter_short_and_long_utterances( + train_cuts, sp, params.zipformer_downsampling_factors + ) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours