support using mini librispeech in training (#1048)

* support mini librispeech in training

* update onnx export doc
This commit is contained in:
Fangjun Kuang 2023-05-09 15:10:06 +08:00 committed by GitHub
parent ebbab37776
commit 5b50ffda54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 11 deletions

View File

@ -3,6 +3,15 @@ Export to ONNX
In this section, we describe how to export models to `ONNX`_.
.. hint::
Before you continue, please run:
.. code-block:: bash
pip install onnx
In each recipe, there is a file called ``export-onnx.py``, which is used
to export trained models to `ONNX`_.

View File

@ -60,13 +60,13 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from multidataset import MultiDataset
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from multidataset import MultiDataset
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@ -1056,7 +1056,9 @@ def run(rank, world_size, args):
multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
train_cuts = multidataset.train_cuts()
else:
if params.full_libri:
if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts()
elif params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
@ -1108,6 +1110,9 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.mini_libri:
valid_cuts = librispeech.dev_clean_2_cuts()
else:
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)

View File

@ -1049,6 +1049,9 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
@ -1104,6 +1107,9 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.mini_libri:
valid_cuts = librispeech.dev_clean_2_cuts()
else:
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)

View File

@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
help="""Used only when --mini-libri is False. When enabled,
use 960h LibriSpeech. Otherwise, use 100h subset.""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
)
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
    """Lazily load the mini-LibriSpeech train-clean-5 cut manifest.

    Returns:
      A :class:`CutSet` backed by the gzipped JSONL manifest in
      ``self.args.manifest_dir``.
    """
    logging.info("mini_librispeech: About to get train-clean-5 cuts")
    manifest_path = (
        self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
    )
    return load_manifest_lazy(manifest_path)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
    """Lazily load the mini-LibriSpeech dev-clean-2 cut manifest.

    Returns:
      A :class:`CutSet` backed by the gzipped JSONL manifest in
      ``self.args.manifest_dir``.
    """
    logging.info("mini_librispeech: About to get dev-clean-2 cuts")
    manifest_path = (
        self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
    )
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")