From bb07b65e45f91fc9b4f2a53a7f82e97fe52d9164 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 26 Jan 2024 10:18:10 +0800 Subject: [PATCH] add remove long short --- egs/wenetspeech/ASR/whisper/train.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 42f0735ad..8e55200e1 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -262,7 +262,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 5000, + "valid_interval": 10000, "env_info": get_env_info(), } ) @@ -441,9 +441,6 @@ def compute_loss( assert feature.ndim == 3 feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) - # make sure feature T no more than 3000, otherwise cut it - if feature.shape[2] > 3000: - feature = feature[:, :, :3000] supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) @@ -822,7 +819,23 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts()) + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts(remove_short_and_long_utt)) valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts()) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)