Fix bugs in HubertDataset

yifanyeung 2024-02-27 22:13:39 +08:00
parent 8515d92f47
commit 99044e1c2b
7 changed files with 86 additions and 46 deletions

View File

@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
@@ -40,6 +41,7 @@ class HubertDataset(torch.utils.data.Dataset):
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -54,6 +56,9 @@ class HubertDataset(torch.utils.data.Dataset):
         self.pad_audio = pad_audio
         self.num_classes = num_classes
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -61,10 +66,11 @@ class HubertDataset(torch.utils.data.Dataset):
         for i, item in enumerate(audio):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
@@ -203,6 +209,7 @@ class HubertAsrDataset(torch.utils.data.Dataset):
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         random_crop: bool = True,
         pad_audio: bool = True,
@@ -213,6 +220,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
         self.random_crop = random_crop
         self.pad_audio = pad_audio
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -221,9 +231,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
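
Not part of the diff: a minimal, self-contained sketch of the crop-size rule the two hunks above introduce. The helper name target_audio_size and the sample counts are made up for illustration; the min()/max() combination and the sys.maxsize fallback mirror the patched HubertDataset/HubertAsrDataset.

import sys
from typing import List, Optional


def target_audio_size(
    audio_lens: List[int],
    pad_audio: bool,
    max_sample_size: Optional[int] = None,
) -> int:
    # None means "no cap", exactly like the patched __init__ default.
    cap = max_sample_size if max_sample_size is not None else sys.maxsize
    if pad_audio:
        # pad shorter cuts up to the longest one, but never beyond the cap
        return min(max(audio_lens), cap)
    # crop longer cuts down to the shortest one, also bounded by the cap
    return min(min(audio_lens), cap)


# A 20 s cut at 16 kHz (320000 samples) is now cropped to 250000 samples.
assert target_audio_size([320000, 160000], pad_audio=True, max_sample_size=250000) == 250000
assert target_audio_size([320000, 160000], pad_audio=False) == 160000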

View File

@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -292,17 +293,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample shorter than this.",
     )
 
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
+    )
+
     add_hubert_arguments(parser)
@@ -884,12 +892,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         return True
@@ -905,6 +913,7 @@
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -920,6 +929,7 @@
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
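
Not part of the diff: a small sketch of how the renamed flags fit together, assuming (as the hunks above suggest) that --min-keep-size and --max-keep-size are sample counts compared against cut durations at params.sample_rate, while --max-sample-size only caps the crop length at batching time. keep_cut and the example durations are illustrative, not recipe code.

import sys


def keep_cut(
    duration: float,
    sample_rate: float = 16000,
    min_keep_size: float = 32000,
    max_keep_size: float = sys.maxsize,
) -> bool:
    # Mirror of the duration filter above: thresholds are given in samples
    # and converted to seconds before comparing against the cut duration.
    return min_keep_size / sample_rate <= duration <= max_keep_size / sample_rate


assert keep_cut(10.0)       # 160000 samples at 16 kHz, kept
assert not keep_cut(1.5)    # 24000 samples, shorter than min_keep_size (2 s)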

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -292,17 +293,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample shorter than this.",
     )
 
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
+    )
+
     add_hubert_arguments(parser)
@@ -884,12 +892,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         return True
@@ -905,6 +913,7 @@
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -920,6 +929,7 @@
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,

View File

@@ -136,6 +136,7 @@ class LibriSpeechDataModule:
     def train_dataloaders(
         self,
         cuts_train: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -153,6 +154,7 @@
         """
         logging.info("About to create train dataset")
         train = HubertDataset(
+            max_sample_size=max_sample_size,
             sample_rate=sample_rate,
             label_rate=label_rate,
             random_crop=random_crop,
@@ -202,6 +204,7 @@
     def valid_dataloaders(
         self,
         cuts_valid: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -211,6 +214,7 @@
     ) -> DataLoader:
         logging.info("About to create dev dataset")
         validate = HubertDataset(
+            max_sample_size=max_sample_size,
             sample_rate=sample_rate,
             label_rate=label_rate,
             random_crop=random_crop,
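
Not part of the diff: the datamodule change is pure plumbing. It forwards the new keyword and keeps None as the default, so existing callers behave as before and HubertDataset resolves None to sys.maxsize. A toy sketch of that contract (the class and function names below are placeholders, not recipe code):

import sys
from typing import Optional


class ToyHubertDataset:
    # Stand-in for HubertDataset: None means "no cap" and resolves to sys.maxsize.
    def __init__(self, max_sample_size: Optional[int] = None) -> None:
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )


def toy_train_dataloaders(max_sample_size: Optional[int] = None) -> ToyHubertDataset:
    # The datamodule only passes the argument through, as in the hunks above.
    return ToyHubertDataset(max_sample_size=max_sample_size)


assert toy_train_dataloaders().max_sample_size == sys.maxsize
assert toy_train_dataloaders(max_sample_size=250000).max_sample_size == 250000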

View File

@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -592,17 +593,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample shorter than this.",
     )
 
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
+    )
+
     add_model_arguments(parser)
@@ -1182,12 +1190,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         return True
@@ -1203,6 +1211,7 @@
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -1218,6 +1227,7 @@
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,