mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-28 19:24:17 +00:00
Fix bugs in HubertDataset
This commit is contained in:
parent
8515d92f47
commit
99044e1c2b
@ -14,7 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
import sys
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -40,6 +41,7 @@ class HubertDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
label_rate: float = 50,
|
label_rate: float = 50,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
@ -54,6 +56,9 @@ class HubertDataset(torch.utils.data.Dataset):
|
|||||||
self.pad_audio = pad_audio
|
self.pad_audio = pad_audio
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.normalize = do_normalize
|
self.normalize = do_normalize
|
||||||
|
self.max_sample_size = (
|
||||||
|
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||||
|
)
|
||||||
|
|
||||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
||||||
self._validate(cuts)
|
self._validate(cuts)
|
||||||
@ -61,10 +66,11 @@ class HubertDataset(torch.utils.data.Dataset):
|
|||||||
for i, item in enumerate(audio):
|
for i, item in enumerate(audio):
|
||||||
audio[i] = self.postprocess(item, self.sample_rate)
|
audio[i] = self.postprocess(item, self.sample_rate)
|
||||||
audio_lens = [cut.num_samples for cut in cuts]
|
audio_lens = [cut.num_samples for cut in cuts]
|
||||||
|
|
||||||
if self.pad_audio:
|
if self.pad_audio:
|
||||||
audio_size = max(audio_lens)
|
audio_size = min(max(audio_lens), self.max_sample_size)
|
||||||
else:
|
else:
|
||||||
audio_size = min(audio_lens)
|
audio_size = min(min(audio_lens), self.max_sample_size)
|
||||||
|
|
||||||
audio, padding_mask, audio_starts = self.collater_audio(
|
audio, padding_mask, audio_starts = self.collater_audio(
|
||||||
audio, audio_lens, audio_size
|
audio, audio_lens, audio_size
|
||||||
@ -203,6 +209,7 @@ class HubertAsrDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
pad_audio: bool = True,
|
pad_audio: bool = True,
|
||||||
@ -213,6 +220,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
|
|||||||
self.random_crop = random_crop
|
self.random_crop = random_crop
|
||||||
self.pad_audio = pad_audio
|
self.pad_audio = pad_audio
|
||||||
self.normalize = do_normalize
|
self.normalize = do_normalize
|
||||||
|
self.max_sample_size = (
|
||||||
|
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||||
|
)
|
||||||
|
|
||||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
||||||
self._validate(cuts)
|
self._validate(cuts)
|
||||||
@ -221,9 +231,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
|
|||||||
audio[i] = self.postprocess(item, self.sample_rate)
|
audio[i] = self.postprocess(item, self.sample_rate)
|
||||||
audio_lens = [cut.num_samples for cut in cuts]
|
audio_lens = [cut.num_samples for cut in cuts]
|
||||||
if self.pad_audio:
|
if self.pad_audio:
|
||||||
audio_size = max(audio_lens)
|
audio_size = min(max(audio_lens), self.max_sample_size)
|
||||||
else:
|
else:
|
||||||
audio_size = min(audio_lens)
|
audio_size = min(min(audio_lens), self.max_sample_size)
|
||||||
|
|
||||||
audio, padding_mask, audio_starts = self.collater_audio(
|
audio, padding_mask, audio_starts = self.collater_audio(
|
||||||
audio, audio_lens, audio_size
|
audio, audio_lens, audio_size
|
||||||
|
@ -1015,10 +1015,8 @@ def main():
|
|||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_sets = ["dev-clean", "dev-other"]
|
test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
|
||||||
test_dl = [dev_clean_dl, dev_other_dl]
|
test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
|
||||||
# test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
|
|
||||||
# test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
@ -1015,10 +1015,8 @@ def main():
|
|||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_sets = ["dev-clean", "dev-other"]
|
test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
|
||||||
test_dl = [dev_clean_dl, dev_other_dl]
|
test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
|
||||||
# test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
|
|
||||||
# test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -292,17 +293,24 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sample-size",
|
"--max-keep-size",
|
||||||
type=float,
|
type=int,
|
||||||
default=250000,
|
default=sys.maxsize,
|
||||||
help="max sample size",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-sample-size",
|
"--min-keep-size",
|
||||||
type=float,
|
type=float,
|
||||||
default=32000,
|
default=32000,
|
||||||
help="min sample size",
|
help="exclude samples shorter than this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sample-size",
|
||||||
|
type=float,
|
||||||
|
default=250000,
|
||||||
|
help="max sample size to crop to for batching.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_hubert_arguments(parser)
|
add_hubert_arguments(parser)
|
||||||
@ -884,12 +892,12 @@ def run(rank, world_size, args):
|
|||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if (
|
if (
|
||||||
c.duration < params.min_sample_size / params.sample_rate
|
c.duration < params.min_keep_size / params.sample_rate
|
||||||
or c.duration > params.max_sample_size / params.sample_rate
|
or c.duration > params.max_keep_size / params.sample_rate
|
||||||
):
|
):
|
||||||
# logging.warning(
|
logging.warning(
|
||||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
# )
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -905,6 +913,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(
|
train_dl = librispeech.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
@ -920,6 +929,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
|
@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -292,17 +293,24 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sample-size",
|
"--max-keep-size",
|
||||||
type=float,
|
type=int,
|
||||||
default=250000,
|
default=sys.maxsize,
|
||||||
help="max sample size",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-sample-size",
|
"--min-keep-size",
|
||||||
type=float,
|
type=float,
|
||||||
default=32000,
|
default=32000,
|
||||||
help="min sample size",
|
help="exclude samples shorter than this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sample-size",
|
||||||
|
type=float,
|
||||||
|
default=250000,
|
||||||
|
help="max sample size to crop to for batching.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_hubert_arguments(parser)
|
add_hubert_arguments(parser)
|
||||||
@ -884,12 +892,12 @@ def run(rank, world_size, args):
|
|||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if (
|
if (
|
||||||
c.duration < params.min_sample_size / params.sample_rate
|
c.duration < params.min_keep_size / params.sample_rate
|
||||||
or c.duration > params.max_sample_size / params.sample_rate
|
or c.duration > params.max_keep_size / params.sample_rate
|
||||||
):
|
):
|
||||||
# logging.warning(
|
logging.warning(
|
||||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
# )
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -905,6 +913,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(
|
train_dl = librispeech.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
@ -920,6 +929,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
|
@ -136,6 +136,7 @@ class LibriSpeechDataModule:
|
|||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
label_rate: float = 50,
|
label_rate: float = 50,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
@ -153,6 +154,7 @@ class LibriSpeechDataModule:
|
|||||||
"""
|
"""
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = HubertDataset(
|
train = HubertDataset(
|
||||||
|
max_sample_size=max_sample_size,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
label_rate=label_rate,
|
label_rate=label_rate,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
@ -202,6 +204,7 @@ class LibriSpeechDataModule:
|
|||||||
def valid_dataloaders(
|
def valid_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_valid: CutSet,
|
cuts_valid: CutSet,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
label_rate: float = 50,
|
label_rate: float = 50,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
@ -211,6 +214,7 @@ class LibriSpeechDataModule:
|
|||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertDataset(
|
validate = HubertDataset(
|
||||||
|
max_sample_size=max_sample_size,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
label_rate=label_rate,
|
label_rate=label_rate,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
|
@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -592,17 +593,24 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sample-size",
|
"--max-keep-size",
|
||||||
type=float,
|
type=int,
|
||||||
default=250000,
|
default=sys.maxsize,
|
||||||
help="max sample size",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-sample-size",
|
"--min-keep-size",
|
||||||
type=float,
|
type=float,
|
||||||
default=32000,
|
default=32000,
|
||||||
help="min sample size",
|
help="exclude samples shorter than this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sample-size",
|
||||||
|
type=float,
|
||||||
|
default=250000,
|
||||||
|
help="max sample size to crop to for batching.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -1182,12 +1190,12 @@ def run(rank, world_size, args):
|
|||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if (
|
if (
|
||||||
c.duration < params.min_sample_size / params.sample_rate
|
c.duration < params.min_keep_size / params.sample_rate
|
||||||
or c.duration > params.max_sample_size / params.sample_rate
|
or c.duration > params.max_keep_size / params.sample_rate
|
||||||
):
|
):
|
||||||
# logging.warning(
|
logging.warning(
|
||||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
# )
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -1203,6 +1211,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(
|
train_dl = librispeech.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
@ -1218,6 +1227,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user