Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-28 19:24:17 +00:00)
Fix bugs in HubertDataset
parent 8515d92f47
commit 99044e1c2b
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
@@ -40,6 +41,7 @@ class HubertDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -54,6 +56,9 @@ class HubertDataset(torch.utils.data.Dataset):
         self.pad_audio = pad_audio
         self.num_classes = num_classes
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -61,10 +66,11 @@ class HubertDataset(torch.utils.data.Dataset):
         for i, item in enumerate(audio):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
+
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
 
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
@@ -203,6 +209,7 @@ class HubertAsrDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         random_crop: bool = True,
         pad_audio: bool = True,
@@ -213,6 +220,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
         self.random_crop = random_crop
         self.pad_audio = pad_audio
         self.normalize = do_normalize
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
         self._validate(cuts)
@@ -221,9 +231,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             audio[i] = self.postprocess(item, self.sample_rate)
         audio_lens = [cut.num_samples for cut in cuts]
         if self.pad_audio:
-            audio_size = max(audio_lens)
+            audio_size = min(max(audio_lens), self.max_sample_size)
         else:
-            audio_size = min(audio_lens)
+            audio_size = min(min(audio_lens), self.max_sample_size)
 
         audio, padding_mask, audio_starts = self.collater_audio(
             audio, audio_lens, audio_size
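
For reference, the bug being fixed is that HubertDataset and HubertAsrDataset previously ignored any crop limit: with pad_audio=True the batch was padded to the longest cut, however long. The hunks above cap the batch size at max_sample_size (defaulting to "no cap" via sys.maxsize). A minimal sketch of the new rule, assuming only a list of per-cut sample counts; the helper name target_batch_size is hypothetical and introduced purely for illustration:

import sys

def target_batch_size(audio_lens, max_sample_size=None, pad_audio=True):
    # Mirror the constructor's fallback: no explicit cap means sys.maxsize.
    cap = max_sample_size if max_sample_size is not None else sys.maxsize
    if pad_audio:
        # Pad up to the longest cut, but never beyond the cap; cuts longer
        # than the cap are cropped by the collater instead of padded to.
        return min(max(audio_lens), cap)
    # Without padding, crop down to the shortest cut, still bounded by the cap.
    return min(min(audio_lens), cap)

# Example: with a 250000-sample cap and pad_audio=True, lengths of
# [180000, 320000] yield a 250000-sample batch, so the long cut is cropped.
assert target_batch_size([180000, 320000], max_sample_size=250000) == 250000
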
@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -1015,10 +1015,8 @@ def main():
         do_normalize=params.do_normalize,
     )
 
-    test_sets = ["dev-clean", "dev-other"]
-    test_dl = [dev_clean_dl, dev_other_dl]
-    # test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
-    # test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
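
The two decode hunks above only swap the commented-out full test list in for the dev-only one. Since each set name is paired with its dataloader by zip, both lists must stay aligned; a minimal sketch with placeholder strings standing in for the real dataloaders:

# Hypothetical stand-ins for the dataloaders built earlier in main().
dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl = "dl-1", "dl-2", "dl-3", "dl-4"

test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
test_dls = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]

# Each name is decoded with its matching dataloader, so the two lists
# must be extended together -- which is what the changed lines do.
for test_set, dl in zip(test_sets, test_dls):
    print(f"decoding {test_set} with {dl}")
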
@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -292,17 +293,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample longer less than this.",
+    )
+
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
     )
 
     add_hubert_arguments(parser)
@@ -884,12 +892,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         return True
@@ -905,6 +913,7 @@ def run(rank, world_size, args):
 
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -920,6 +929,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
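
The renamed --min-keep-size/--max-keep-size flags are expressed in samples, so the filter in run() divides by the sample rate before comparing against a cut's duration in seconds. A minimal sketch of that predicate, assuming 16 kHz audio; the function name keep_cut and its plain float arguments are illustrative, not the script's exact helper:

import sys

def keep_cut(duration_s, sample_rate=16000, min_keep_size=32000, max_keep_size=sys.maxsize):
    # Convert the sample-count thresholds to seconds before comparing.
    min_s = min_keep_size / sample_rate  # 32000 / 16000 = 2.0 s
    max_s = max_keep_size / sample_rate
    return min_s <= duration_s <= max_s

# With the defaults, a 1.5 s cut is excluded, while a 30 s cut is kept
# (it is later cropped to --max-sample-size samples rather than dropped).
assert not keep_cut(1.5)
assert keep_cut(30.0)
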
@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -292,17 +293,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample longer less than this.",
+    )
+
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
     )
 
     add_hubert_arguments(parser)
@@ -884,12 +892,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
            return False
 
         return True
@@ -905,6 +913,7 @@ def run(rank, world_size, args):
 
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -920,6 +929,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
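
To see how the three size flags resolve, here is a stripped-down parser containing only the arguments touched by this commit; the real get_parser() defines many more options, and the min-keep help string is reworded here for clarity:

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--max-keep-size", type=int, default=sys.maxsize,
                    help="exclude sample longer than this.")
parser.add_argument("--min-keep-size", type=float, default=32000,
                    help="exclude sample shorter than this.")
parser.add_argument("--max-sample-size", type=float, default=250000,
                    help="max sample size to crop to for batching.")

args = parser.parse_args([])
# Defaults: keep everything up to sys.maxsize samples, drop cuts under
# 32000 samples (2 s at 16 kHz), and crop batches to 250000 samples.
print(args.max_keep_size, args.min_keep_size, args.max_sample_size)
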
@@ -136,6 +136,7 @@ class LibriSpeechDataModule:
     def train_dataloaders(
         self,
         cuts_train: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -153,6 +154,7 @@ class LibriSpeechDataModule:
         """
         logging.info("About to create train dataset")
         train = HubertDataset(
+            max_sample_size=max_sample_size,
             sample_rate=sample_rate,
             label_rate=label_rate,
             random_crop=random_crop,
@@ -202,6 +204,7 @@ class LibriSpeechDataModule:
     def valid_dataloaders(
         self,
         cuts_valid: CutSet,
+        max_sample_size: Optional[int] = None,
         sample_rate: float = 16000,
         label_rate: float = 50,
         random_crop: bool = True,
@@ -211,6 +214,7 @@ class LibriSpeechDataModule:
     ) -> DataLoader:
         logging.info("About to create dev dataset")
         validate = HubertDataset(
+            max_sample_size=max_sample_size,
             sample_rate=sample_rate,
             label_rate=label_rate,
             random_crop=random_crop,
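
These data-module hunks simply forward the new keyword into HubertDataset. A minimal sketch of that plumbing with simplified stand-in classes (the real LibriSpeechDataModule and HubertDataset take many more arguments and build Lhotse samplers and dataloaders):

import sys
from typing import Optional

class HubertDatasetSketch:
    # Stand-in for HubertDataset; only the new keyword is modeled.
    def __init__(self, max_sample_size: Optional[int] = None):
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )

class DataModuleSketch:
    # Stand-in for LibriSpeechDataModule: the fix is just passing
    # max_sample_size from train_dataloaders()/valid_dataloaders()
    # through to the dataset constructor.
    def train_dataset(self, max_sample_size: Optional[int] = None):
        return HubertDatasetSketch(max_sample_size=max_sample_size)

ds = DataModuleSketch().train_dataset(max_sample_size=250000)
assert ds.max_sample_size == 250000  # cap forwarded
assert DataModuleSketch().train_dataset().max_sample_size == sys.maxsize  # no cap by default
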
@@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -592,17 +593,24 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--max-sample-size",
-        type=float,
-        default=250000,
-        help="max sample size",
+        "--max-keep-size",
+        type=int,
+        default=sys.maxsize,
+        help="exclude sample longer than this.",
     )
 
     parser.add_argument(
-        "--min-sample-size",
+        "--min-keep-size",
         type=float,
         default=32000,
-        help="min sample size",
+        help="exclude sample longer less than this.",
+    )
+
+    parser.add_argument(
+        "--max-sample-size",
+        type=float,
+        default=250000,
+        help="max sample size to crop to for batching.",
     )
 
     add_model_arguments(parser)
@@ -1182,12 +1190,12 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if (
-            c.duration < params.min_sample_size / params.sample_rate
-            or c.duration > params.max_sample_size / params.sample_rate
+            c.duration < params.min_keep_size / params.sample_rate
+            or c.duration > params.max_keep_size / params.sample_rate
         ):
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         return True
@@ -1203,6 +1211,7 @@ def run(rank, world_size, args):
 
     train_dl = librispeech.train_dataloaders(
         train_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
@@ -1218,6 +1227,7 @@ def run(rank, world_size, args):
 
     valid_dl = librispeech.valid_dataloaders(
         valid_cuts,
+        max_sample_size=params.max_sample_size,
         sample_rate=params.sample_rate,
         label_rate=params.label_rate,
         random_crop=params.random_crop,
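
Taken together, the defaults define two different thresholds at 16 kHz: cuts shorter than --min-keep-size are excluded from training outright, while cuts longer than --max-sample-size are kept and merely cropped when batched. A quick worked check of the effective durations (plain arithmetic, not code from the repository):

sample_rate = 16000       # Hz, the recipes' default
min_keep_size = 32000     # samples: hard lower bound for training cuts
max_sample_size = 250000  # samples: crop length used when batching

print(min_keep_size / sample_rate)    # 2.0 s -> shorter cuts are dropped
print(max_sample_size / sample_rate)  # 15.625 s -> longer cuts are cropped, not dropped
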