Fix bugs in HubertDataset

yifanyeung 2024-02-27 22:13:39 +08:00
parent 8515d92f47
commit 99044e1c2b
7 changed files with 86 additions and 46 deletions

View File

@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
@@ -40,6 +41,7 @@ class HubertDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -54,6 +56,9 @@ class HubertDataset(torch.utils.data.Dataset):
         self.pad_audio = pad_audio
         self.num_classes = num_classes
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -61,10 +66,11 @@ class HubertDataset(torch.utils.data.Dataset):
         for i, item in enumerate(audio):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
 
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
+
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
@@ -203,6 +209,7 @@ class HubertAsrDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         random_crop: bool = True,
         pad_audio: bool = True,
@@ -213,6 +220,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
         self.random_crop = random_crop
         self.pad_audio = pad_audio
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -221,9 +231,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
 
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
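
The substance of the fix is the same in both dataset classes: the collate size is now capped by max_sample_size instead of being taken straight from the batch. A minimal standalone sketch of the resulting behavior; the function name is hypothetical (the real logic lives inside __getitem__):

import sys
from typing import List, Optional

def target_audio_size(
    audio_lens: List[int],
    pad_audio: bool,
    max_sample_size: Optional[int] = None,
) -> int:
    """Pad up to the longest cut, or crop down to the shortest,
    but never exceed max_sample_size samples."""
    cap = max_sample_size if max_sample_size is not None else sys.maxsize
    if pad_audio:
        return min(max(audio_lens), cap)
    return min(min(audio_lens), cap)

# Three cuts of 4 s, 10 s, and 20 s at 16 kHz, with the 250k-sample default cap:
lens = [64000, 160000, 320000]
assert target_audio_size(lens, pad_audio=True, max_sample_size=250000) == 250000
assert target_audio_size(lens, pad_audio=False, max_sample_size=250000) == 64000

Before the fix, a single long outlier dictated the padded length of the whole batch; with the cap, per-batch memory stays bounded regardless of outliers.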

View File

@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@ -292,17 +293,24 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--max-sample-size", "--max-keep-size",
type=float, type=int,
default=250000, default=sys.maxsize,
help="max sample size", help="exclude sample longer than this.",
) )
parser.add_argument( parser.add_argument(
"--min-sample-size", "--min-keep-size",
type=float, type=float,
default=32000, default=32000,
help="min sample size", help="exclude sample longer less than this.",
)
parser.add_argument(
"--max-sample-size",
type=float,
default=250000,
help="max sample size to crop to for batching.",
) )
add_hubert_arguments(parser) add_hubert_arguments(parser)
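
Note the units: the keep sizes and --max-sample-size are counted in audio samples, so the defaults map to seconds only through the sample rate. A quick sanity check at the recipe's 16 kHz default:

sample_rate = 16000       # Hz
min_keep_size = 32000     # samples: cuts shorter than 2.0 s are dropped
max_sample_size = 250000  # samples: random crops are capped at 15.625 s
print(min_keep_size / sample_rate)    # 2.0
print(max_sample_size / sample_rate)  # 15.625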
@ -884,12 +892,12 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select # an utterance duration distribution for your dataset to select
# the threshold # the threshold
if ( if (
c.duration < params.min_sample_size / params.sample_rate c.duration < params.min_keep_size / params.sample_rate
or c.duration > params.max_sample_size / params.sample_rate or c.duration > params.max_keep_size / params.sample_rate
): ):
# logging.warning( logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# ) )
return False return False
return True return True
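
The renaming separates two concerns: --min-keep-size/--max-keep-size exclude cuts from training outright, while --max-sample-size only bounds the crop taken at collate time. A sketch of the (now re-enabled) warning-and-filter logic as a standalone predicate, with the defaults inlined; the function name and the lhotse Cut argument are assumptions based on the surrounding recipe:

import logging
import sys

def remove_short_and_long_utt(c) -> bool:
    """Keep a cut only if its duration lies inside the keep window.

    Sizes are in samples, durations in seconds, hence the division."""
    min_keep_size = 32000        # --min-keep-size default
    max_keep_size = sys.maxsize  # --max-keep-size default
    sample_rate = 16000
    if (
        c.duration < min_keep_size / sample_rate
        or c.duration > max_keep_size / sample_rate
    ):
        logging.warning(
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
        )
        return False
    return True

# Typically applied as: train_cuts = train_cuts.filter(remove_short_and_long_utt)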
@@ -905,6 +913,7 @@ def run(rank, world_size, args):
 
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -920,6 +929,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@ -292,17 +293,24 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--max-sample-size", "--max-keep-size",
type=float, type=int,
default=250000, default=sys.maxsize,
help="max sample size", help="exclude sample longer than this.",
) )
parser.add_argument( parser.add_argument(
"--min-sample-size", "--min-keep-size",
type=float, type=float,
default=32000, default=32000,
help="min sample size", help="exclude sample longer less than this.",
)
parser.add_argument(
"--max-sample-size",
type=float,
default=250000,
help="max sample size to crop to for batching.",
) )
add_hubert_arguments(parser) add_hubert_arguments(parser)
@@ -884,12 +892,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
         return True
 
@ -905,6 +913,7 @@ def run(rank, world_size, args):
train_dl = librispeech.train_dataloaders( train_dl = librispeech.train_dataloaders(
train_cuts, train_cuts,
max_sample_size=params.max_sample_size,
sample_rate=params.sample_rate, sample_rate=params.sample_rate,
label_rate=params.label_rate, label_rate=params.label_rate,
random_crop=params.random_crop, random_crop=params.random_crop,
@@ -920,6 +929,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,

View File

@@ -136,6 +136,7 @@ class LibriSpeechDataModule:
     def train_dataloaders(
         self,
         cuts_train: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@ -153,6 +154,7 @@ class LibriSpeechDataModule:
""" """
logging.info("About to create train dataset") logging.info("About to create train dataset")
train = HubertDataset( train = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate, sample_rate=sample_rate,
label_rate=label_rate, label_rate=label_rate,
random_crop=random_crop, random_crop=random_crop,
@@ -202,6 +204,7 @@ class LibriSpeechDataModule:
     def valid_dataloaders(
         self,
         cuts_valid: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@ -211,6 +214,7 @@ class LibriSpeechDataModule:
) -> DataLoader: ) -> DataLoader:
logging.info("About to create dev dataset") logging.info("About to create dev dataset")
validate = HubertDataset( validate = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate, sample_rate=sample_rate,
label_rate=label_rate, label_rate=label_rate,
random_crop=random_crop, random_crop=random_crop,
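
With the datamodule forwarding the new argument, each caller controls the crop cap per dataloader, and passing None falls back to sys.maxsize (no cap). A small usage sketch, assuming the recipe's LibriSpeechDataModule and a lhotse CutSet are in scope:

def make_train_dl(librispeech, train_cuts):
    # librispeech: LibriSpeechDataModule; train_cuts: lhotse CutSet
    # (both assumed to exist in the recipe, per the diff above).
    return librispeech.train_dataloaders(
        train_cuts,
        max_sample_size=250000,  # cap random crops at ~15.6 s of 16 kHz audio
        sample_rate=16000,
        label_rate=50,
    )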

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@ -592,17 +593,24 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--max-sample-size", "--max-keep-size",
type=float, type=int,
default=250000, default=sys.maxsize,
help="max sample size", help="exclude sample longer than this.",
) )
parser.add_argument( parser.add_argument(
"--min-sample-size", "--min-keep-size",
type=float, type=float,
default=32000, default=32000,
help="min sample size", help="exclude sample longer less than this.",
)
parser.add_argument(
"--max-sample-size",
type=float,
default=250000,
help="max sample size to crop to for batching.",
) )
add_model_arguments(parser) add_model_arguments(parser)
@@ -1182,12 +1190,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
         return True
 
@@ -1203,6 +1211,7 @@ def run(rank, world_size, args):
 
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -1218,6 +1227,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,