Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00

support using mini librispeech in training (#1048)

* support mini librispeech in training
* update onnx export doc

This commit is contained in:
parent ebbab37776
commit 5b50ffda54
@@ -3,6 +3,15 @@ Export to ONNX

 In this section, we describe how to export models to `ONNX`_.

+.. hint::
+
+   Before you continue, please run:
+
+   .. code-block:: bash
+
+     pip install onnx
+
+
 In each recipe, there is a file called ``export-onnx.py``, which is used
 to export trained models to `ONNX`_.
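The new hint is there because ``export-onnx.py`` needs the ``onnx`` package at import time. After an export, a quick way to confirm the file is a well-formed ONNX graph is the ``onnx.checker`` API; a minimal sketch, assuming a hypothetical output path ``model.onnx`` (the recipes pick their own filenames):

    import onnx

    # Load the exported file and run ONNX's built-in structural checks.
    # "model.onnx" is a placeholder path, not one used by the recipes.
    model = onnx.load("model.onnx")
    onnx.checker.check_model(model)
    print(f"valid ONNX graph with {len(model.graph.node)} nodes")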
@@ -60,13 +60,13 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from multidataset import MultiDataset
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
+from multidataset import MultiDataset
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@@ -1056,7 +1056,9 @@ def run(rank, world_size, args):
         multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
         train_cuts = multidataset.train_cuts()
     else:
-        if params.full_libri:
+        if params.mini_libri:
+            train_cuts = librispeech.train_clean_5_cuts()
+        elif params.full_libri:
             train_cuts = librispeech.train_all_shuf_cuts()
         else:
             train_cuts = librispeech.train_clean_100_cuts()
@@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

     if not params.use_multidataset and not params.print_diagnostics:
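The two hunks above implement one precedence rule: ``--use-multidataset`` (in recipes that have it) is checked first, then ``--mini-libri``, and ``--full-libri`` is consulted last. A condensed sketch of that ordering; the method names come from the diff, while the stubs are invented for illustration:

    from types import SimpleNamespace

    def pick_train_cuts(params, librispeech, multidataset=None):
        # Mirrors the branch order in run(): --use-multidataset first,
        # then --mini-libri, and --full-libri only as the final fallback.
        if getattr(params, "use_multidataset", False):
            return multidataset.train_cuts()
        if params.mini_libri:
            return librispeech.train_clean_5_cuts()
        if params.full_libri:
            return librispeech.train_all_shuf_cuts()
        return librispeech.train_clean_100_cuts()

    # Toy stand-ins, just to demonstrate the precedence:
    toy = SimpleNamespace(
        train_clean_5_cuts=lambda: "train-clean-5",
        train_all_shuf_cuts=lambda: "train-all-shuf",
        train_clean_100_cuts=lambda: "train-clean-100",
    )
    flags = SimpleNamespace(use_multidataset=False, mini_libri=True, full_libri=True)
    print(pick_train_cuts(flags, toy))  # train-clean-5: --mini-libri wins over --full-libri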
@@ -1049,10 +1049,13 @@ def run(rank, world_size, args):

     librispeech = LibriSpeechAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts()
+            train_cuts += librispeech.train_other_500_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
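The trailing context shows ``remove_short_and_long_utt``, whose body the hunk cuts off. For orientation, a sketch of the usual shape of that filter in icefall recipes; the bounds come from the comment above, and the one-line body is an assumption rather than the committed code:

    from lhotse.cut import Cut

    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds,
        # per the recipe's own comment; everything else is dropped.
        return 1.0 <= c.duration <= 20.0

    # Typically applied to the CutSet before building the sampler:
    # train_cuts = train_cuts.filter(remove_short_and_long_utt)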
@@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

     # if not params.print_diagnostics:
@@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="""Used only when --mini-libri is False. When enabled,
+            use 960h LibriSpeech. Otherwise, use 100h subset.""",
         )
+        group.add_argument(
+            "--mini-libri",
+            type=str2bool,
+            default=False,
+            help="True for mini librispeech",
+        )

         group.add_argument(
             "--manifest-dir",
             type=Path,
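Because ``--mini-libri`` defaults to ``False``, existing command lines keep their old behavior. Both flags are parsed with icefall's ``str2bool``, so users pass ``True``/``False`` as strings. A self-contained sketch of how the pair behaves; the local ``str2bool`` is an approximation, not the upstream implementation:

    import argparse

    def str2bool(v):
        # Approximation of icefall.utils.str2bool: lets users write
        # --mini-libri True / --mini-libri false on the command line.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="dataset options")
    group.add_argument("--full-libri", type=str2bool, default=True)
    group.add_argument("--mini-libri", type=str2bool, default=False)

    args = parser.parse_args(["--mini-libri", "True"])
    # --full-libri keeps its default but is ignored whenever --mini-libri is True.
    print(args.mini_libri, args.full_libri)  # True True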
@@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
             )
         return test_dl

+    @lru_cache()
+    def train_clean_5_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+        )
+
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
@@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
         )

+    @lru_cache()
+    def dev_clean_2_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
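Both new accessors follow the data module's existing pattern: ``@lru_cache()`` memoizes the call, and lhotse's ``load_manifest_lazy`` opens the gzipped JSONL manifest lazily instead of materializing every cut up front. A stripped-down sketch of that pattern, assuming lhotse is installed; ``MiniLibriCuts`` is a made-up wrapper standing in for ``LibriSpeechAsrDataModule``:

    import logging
    from functools import lru_cache
    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy

    class MiniLibriCuts:
        def __init__(self, manifest_dir: Path):
            self.manifest_dir = manifest_dir

        @lru_cache()
        def train_clean_5_cuts(self) -> CutSet:
            # lru_cache memoizes per instance, so repeated calls reuse one
            # lazily opened manifest; cuts stream from the .jsonl.gz on demand.
            logging.info("mini_librispeech: About to get train-clean-5 cuts")
            return load_manifest_lazy(
                self.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
            )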