From 99044e1c2b85e2b61cbae86bb2f1558d5019254d Mon Sep 17 00:00:00 2001 From: yifanyeung Date: Tue, 27 Feb 2024 22:13:39 +0800 Subject: [PATCH] Fix bugs in HubertDataset --- egs/librispeech/SSL/hubert/dataset.py | 20 +++++++++--- egs/librispeech/SSL/hubert/decode.py | 6 ++-- egs/librispeech/SSL/hubert/decode_ce.py | 6 ++-- egs/librispeech/SSL/hubert/pretrain.py | 32 +++++++++++++------- egs/librispeech/SSL/hubert/pretrain_ce.py | 32 +++++++++++++------- egs/librispeech/SSL/hubert/ssl_datamodule.py | 4 +++ egs/librispeech/SSL/zipformer/pretrain.py | 32 +++++++++++++------- 7 files changed, 86 insertions(+), 46 deletions(-) diff --git a/egs/librispeech/SSL/hubert/dataset.py b/egs/librispeech/SSL/hubert/dataset.py index 1ee8f3e32..c41ec8bb7 100644 --- a/egs/librispeech/SSL/hubert/dataset.py +++ b/egs/librispeech/SSL/hubert/dataset.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +import sys +from typing import Any, Dict, Optional import numpy as np import torch @@ -40,6 +41,7 @@ class HubertDataset(torch.utils.data.Dataset): def __init__( self, + max_sample_size: Optional[int] = None, sample_rate: float = 16000, label_rate: float = 50, random_crop: bool = True, @@ -54,6 +56,9 @@ class HubertDataset(torch.utils.data.Dataset): self.pad_audio = pad_audio self.num_classes = num_classes self.normalize = do_normalize + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) def __getitem__(self, cuts: CutSet) -> Dict[str, Any]: self._validate(cuts) @@ -61,10 +66,11 @@ class HubertDataset(torch.utils.data.Dataset): for i, item in enumerate(audio): audio[i] = self.postprocess(item, self.sample_rate) audio_lens = [cut.num_samples for cut in cuts] + if self.pad_audio: - audio_size = max(audio_lens) + audio_size = min(max(audio_lens), self.max_sample_size) else: - audio_size = min(audio_lens) + audio_size = min(min(audio_lens), self.max_sample_size) audio, padding_mask, audio_starts = self.collater_audio( audio, audio_lens, audio_size @@ -203,6 +209,7 @@ class HubertAsrDataset(torch.utils.data.Dataset): def __init__( self, + max_sample_size: Optional[int] = None, sample_rate: float = 16000, random_crop: bool = True, pad_audio: bool = True, @@ -213,6 +220,9 @@ class HubertAsrDataset(torch.utils.data.Dataset): self.random_crop = random_crop self.pad_audio = pad_audio self.normalize = do_normalize + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) def __getitem__(self, cuts: CutSet) -> Dict[str, Any]: self._validate(cuts) @@ -221,9 +231,9 @@ class HubertAsrDataset(torch.utils.data.Dataset): audio[i] = self.postprocess(item, self.sample_rate) audio_lens = [cut.num_samples for cut in cuts] if self.pad_audio: - audio_size = max(audio_lens) + audio_size = min(max(audio_lens), self.max_sample_size) else: - audio_size = min(audio_lens) + audio_size = min(min(audio_lens), self.max_sample_size) audio, padding_mask, audio_starts = self.collater_audio( audio, audio_lens, audio_size diff --git a/egs/librispeech/SSL/hubert/decode.py b/egs/librispeech/SSL/hubert/decode.py index 3b540f233..837061b8c 100644 --- a/egs/librispeech/SSL/hubert/decode.py +++ b/egs/librispeech/SSL/hubert/decode.py @@ -1015,10 +1015,8 @@ def main(): do_normalize=params.do_normalize, ) - test_sets = ["dev-clean", "dev-other"] - test_dl = [dev_clean_dl, dev_other_dl] - # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/librispeech/SSL/hubert/decode_ce.py b/egs/librispeech/SSL/hubert/decode_ce.py index 656b8b0b5..a8d8bc9c2 100644 --- a/egs/librispeech/SSL/hubert/decode_ce.py +++ b/egs/librispeech/SSL/hubert/decode_ce.py @@ -1015,10 +1015,8 @@ def main(): do_normalize=params.do_normalize, ) - test_sets = ["dev-clean", "dev-other"] - test_dl = [dev_clean_dl, dev_other_dl] - # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] - # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index 89bc53338..99b68f68d 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse import copy import logging +import sys import warnings from pathlib import Path from shutil import copyfile @@ -292,17 +293,24 @@ def get_parser(): ) parser.add_argument( - "--max-sample-size", - type=float, - default=250000, - help="max sample size", + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", ) parser.add_argument( - "--min-sample-size", + "--min-keep-size", type=float, default=32000, - help="min sample size", + help="exclude sample longer less than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", ) add_hubert_arguments(parser) @@ -884,12 +892,12 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold if ( - c.duration < params.min_sample_size / params.sample_rate - or c.duration > params.max_sample_size / params.sample_rate + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate ): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) return False return True @@ -905,6 +913,7 @@ def run(rank, world_size, args): train_dl = librispeech.train_dataloaders( train_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop, @@ -920,6 +929,7 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders( valid_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 755be202e..20a1b06c2 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse import copy import logging +import sys import warnings from pathlib import Path from shutil import copyfile @@ -292,17 +293,24 @@ def get_parser(): ) parser.add_argument( - "--max-sample-size", - type=float, - default=250000, - help="max sample size", + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", ) parser.add_argument( - "--min-sample-size", + "--min-keep-size", type=float, default=32000, - help="min sample size", + help="exclude sample longer less than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", ) add_hubert_arguments(parser) @@ -884,12 +892,12 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold if ( - c.duration < params.min_sample_size / params.sample_rate - or c.duration > params.max_sample_size / params.sample_rate + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate ): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) return False return True @@ -905,6 +913,7 @@ def run(rank, world_size, args): train_dl = librispeech.train_dataloaders( train_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop, @@ -920,6 +929,7 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders( valid_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop, diff --git a/egs/librispeech/SSL/hubert/ssl_datamodule.py b/egs/librispeech/SSL/hubert/ssl_datamodule.py index 07e903600..34e18302f 100644 --- a/egs/librispeech/SSL/hubert/ssl_datamodule.py +++ b/egs/librispeech/SSL/hubert/ssl_datamodule.py @@ -136,6 +136,7 @@ class LibriSpeechDataModule: def train_dataloaders( self, cuts_train: CutSet, + max_sample_size: Optional[int] = None, sample_rate: float = 16000, label_rate: float = 50, random_crop: bool = True, @@ -153,6 +154,7 @@ class LibriSpeechDataModule: """ logging.info("About to create train dataset") train = HubertDataset( + max_sample_size=max_sample_size, sample_rate=sample_rate, label_rate=label_rate, random_crop=random_crop, @@ -202,6 +204,7 @@ class LibriSpeechDataModule: def valid_dataloaders( self, cuts_valid: CutSet, + max_sample_size: Optional[int] = None, sample_rate: float = 16000, label_rate: float = 50, random_crop: bool = True, @@ -211,6 +214,7 @@ class LibriSpeechDataModule: ) -> DataLoader: logging.info("About to create dev dataset") validate = HubertDataset( + max_sample_size=max_sample_size, sample_rate=sample_rate, label_rate=label_rate, random_crop=random_crop, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 1c64c065d..3868b81ad 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse import copy import logging +import sys import warnings from pathlib import Path from shutil import copyfile @@ -592,17 +593,24 @@ def get_parser(): ) parser.add_argument( - "--max-sample-size", - type=float, - default=250000, - help="max sample size", + "--max-keep-size", + type=int, + default=sys.maxsize, + help="exclude sample longer than this.", ) parser.add_argument( - "--min-sample-size", + "--min-keep-size", type=float, default=32000, - help="min sample size", + help="exclude sample longer less than this.", + ) + + parser.add_argument( + "--max-sample-size", + type=float, + default=250000, + help="max sample size to crop to for batching.", ) add_model_arguments(parser) @@ -1182,12 +1190,12 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold if ( - c.duration < params.min_sample_size / params.sample_rate - or c.duration > params.max_sample_size / params.sample_rate + c.duration < params.min_keep_size / params.sample_rate + or c.duration > params.max_keep_size / params.sample_rate ): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) return False return True @@ -1203,6 +1211,7 @@ def run(rank, world_size, args): train_dl = librispeech.train_dataloaders( train_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop, @@ -1218,6 +1227,7 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders( valid_cuts, + max_sample_size=params.max_sample_size, sample_rate=params.sample_rate, label_rate=params.label_rate, random_crop=params.random_crop,