support using mini librispeech in training (#1048)

* support mini librispeech in training

* update onnx export doc
Fangjun Kuang 2023-05-09 15:10:06 +08:00 committed by GitHub
parent ebbab37776
commit 5b50ffda54
4 changed files with 53 additions and 11 deletions

File 1 of 4: ONNX export documentation (reStructuredText)

@@ -3,6 +3,15 @@ Export to ONNX
 In this section, we describe how to export models to `ONNX`_.
 
+.. hint::
+
+   Before you continue, please run:
+
+   .. code-block:: bash
+
+     pip install onnx
+
 In each recipe, there is a file called ``export-onnx.py``, which is used
 to export trained models to `ONNX`_.
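
After installing `onnx`, a quick sanity check on any exported model is `onnx.checker`. A minimal sketch (the file name "model.onnx" is a placeholder, not a path produced by this commit):

    import onnx

    # "model.onnx" is a placeholder; substitute the file written
    # by the recipe's export-onnx.py.
    model = onnx.load("model.onnx")

    # Raises onnx.checker.ValidationError if the model is malformed.
    onnx.checker.check_model(model)

    # Optional: print a human-readable summary of the graph.
    print(onnx.helper.printable_graph(model.graph))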

File 2 of 4: a training script with MultiDataset support

@@ -60,13 +60,13 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from multidataset import MultiDataset
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
+from multidataset import MultiDataset
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@@ -1056,7 +1056,9 @@ def run(rank, world_size, args):
         multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
         train_cuts = multidataset.train_cuts()
     else:
-        if params.full_libri:
+        if params.mini_libri:
+            train_cuts = librispeech.train_clean_5_cuts()
+        elif params.full_libri:
             train_cuts = librispeech.train_all_shuf_cuts()
         else:
             train_cuts = librispeech.train_clean_100_cuts()
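
Note the precedence the new branch introduces: when `--mini-libri` is true, `--full-libri` is never consulted. A standalone sketch of that selection order (the helper name `select_train_cuts` is hypothetical, for illustration only):

    # Hypothetical helper mirroring the branch added above.
    # --mini-libri takes precedence over --full-libri.
    def select_train_cuts(librispeech, mini_libri: bool, full_libri: bool):
        if mini_libri:
            # mini LibriSpeech: the ~5h train-clean-5 subset
            return librispeech.train_clean_5_cuts()
        if full_libri:
            # full 960h LibriSpeech, pre-shuffled
            return librispeech.train_all_shuf_cuts()
        # default: the 100h train-clean-100 subset
        return librispeech.train_clean_100_cuts()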
@@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     if not params.use_multidataset and not params.print_diagnostics:
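
The `+=` on the non-mini branch relies on lhotse's `CutSet` concatenation; with lazily-opened manifests, iteration chains the two sets without materializing either in memory. A sketch, assuming the two manifests exist under a hypothetical data/fbank directory:

    from lhotse import load_manifest_lazy

    # Hypothetical manifest locations, for illustration only.
    dev_clean = load_manifest_lazy("data/fbank/librispeech_cuts_dev-clean.jsonl.gz")
    dev_other = load_manifest_lazy("data/fbank/librispeech_cuts_dev-other.jsonl.gz")

    # CutSet addition yields all of dev_clean, then all of dev_other,
    # as the validation dataloader iterates.
    valid_cuts = dev_clean + dev_other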

File 3 of 4: a training script without MultiDataset (LibriSpeech only)

@@ -1049,10 +1049,13 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts()
+            train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
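
The context lines above end at the duration filter that both training scripts apply to `train_cuts`. That filter is plain lhotse: `CutSet.filter()` takes a predicate and drops non-matching cuts as the set is iterated. A sketch with the bounds named in the comment (assumes `train_cuts` was selected as above):

    from lhotse.cut import Cut

    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds.
        return 1.0 <= c.duration <= 20.0

    # Cuts outside the range are skipped as the sampler iterates the CutSet.
    train_cuts = train_cuts.filter(remove_short_and_long_utt)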
@@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     # if not params.print_diagnostics:

File 4 of 4: asr_datamodule.py (class LibriSpeechAsrDataModule)

@@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="""Used only when --mini-libri is False. When enabled,
+            use the 960h LibriSpeech. Otherwise, use the 100h subset.""",
         )
+        group.add_argument(
+            "--mini-libri",
+            type=str2bool,
+            default=False,
+            help="True to use mini LibriSpeech.",
+        )
         group.add_argument(
             "--manifest-dir",
             type=Path,
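
Both flags are declared with `type=str2bool` rather than `type=bool`, since argparse would otherwise treat any non-empty string, including "False", as true. icefall's helper is commonly implemented along these lines (a sketch, not the verbatim source):

    import argparse

    def str2bool(v):
        # Accept the spellings users typically pass on the command line.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")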
@@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
         )
         return test_dl
 
+    @lru_cache()
+    def train_clean_5_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+        )
+
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
@@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
         )
 
+    @lru_cache()
+    def dev_clean_2_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")