mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
* Add recipe for the yes_no dataset. * Refactoring: Remove unused code. * Add Colab notebook for the yesno dataset. * Add GitHub actions to run yesno. * Fix a typo. * Minor fixes. * Train more epochs for GitHub actions. * Minor fixes. * Minor fixes. * Fix style issues.
61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
# Copyright 2021 Piotr Żelasko
|
|
#
|
|
# See ../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import argparse
|
|
from typing import List, Union
|
|
|
|
from lhotse import CutSet
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class DataModule:
    """
    Base interface for the dataset-related part of a recipe.

    It is meant to read or construct Lhotse cuts and to build the
    Dataset/Sampler/DataLoader objects on top of them.

    Each of the train/valid/test splits has its own accessor method, both
    for the cuts and for the DataLoaders. Since a split may consist of
    several sets (e.g. a corpus with multiple test sets), every accessor is
    allowed to return either a single object or a list of them.

    Every accessor in this base class raises ``NotImplementedError``.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the parsed command-line namespace as ``self.args``."""
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """
        Hook taking the argument parser; the base version registers nothing.

        NOTE(review): presumably subclasses add their dataset-specific
        options to ``parser`` here -- confirm against the concrete recipes.
        """

    def train_cuts(self) -> Union[CutSet, List[CutSet]]:
        """Return the training cuts. Not implemented in the base class."""
        raise NotImplementedError()

    def valid_cuts(self) -> Union[CutSet, List[CutSet]]:
        """Return the validation cuts. Not implemented in the base class."""
        raise NotImplementedError()

    def test_cuts(self) -> Union[CutSet, List[CutSet]]:
        """Return the test cuts. Not implemented in the base class."""
        raise NotImplementedError()

    def train_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        """Return the training DataLoader(s). Not implemented here."""
        raise NotImplementedError()

    def valid_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        """Return the validation DataLoader(s). Not implemented here."""
        raise NotImplementedError()

    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        """Return the test DataLoader(s). Not implemented here."""
        raise NotImplementedError()
|