Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
support using mini librispeech in training (#1048)

* support mini librispeech in training
* update onnx export doc
parent ebbab37776
commit 5b50ffda54
@@ -3,6 +3,15 @@ Export to ONNX
 
 In this section, we describe how to export models to `ONNX`_.
 
+.. hint::
+
+   Before you continue, please run:
+
+   .. code-block:: bash
+
+     pip install onnx
+
+
 In each recipe, there is a file called ``export-onnx.py``, which is used
 to export trained models to `ONNX`_.
 
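The hint added above asks readers to install the onnx package before exporting. A minimal sanity check for whatever export-onnx.py writes out; the filename encoder.onnx is a placeholder, since actual output names vary by recipe:

import onnx

# "encoder.onnx" is a hypothetical path; substitute the file that
# export-onnx.py actually produced for your recipe.
model = onnx.load("encoder.onnx")
onnx.checker.check_model(model)  # raises onnx.checker.ValidationError on a malformed graph
print(f"IR version: {model.ir_version}, opset: {model.opset_import[0].version}")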
@@ -60,13 +60,13 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from multidataset import MultiDataset
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
+from multidataset import MultiDataset
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@@ -1056,7 +1056,9 @@ def run(rank, world_size, args):
         multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
         train_cuts = multidataset.train_cuts()
     else:
-        if params.full_libri:
+        if params.mini_libri:
+            train_cuts = librispeech.train_clean_5_cuts()
+        elif params.full_libri:
             train_cuts = librispeech.train_all_shuf_cuts()
         else:
             train_cuts = librispeech.train_clean_100_cuts()
@@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     if not params.use_multidataset and not params.print_diagnostics:
@@ -1049,10 +1049,13 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts()
+            train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
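This hunk applies the same selection to a second train.py recipe, where --full-libri previously extended the 100h subset in place. The precedence is worth spelling out: --mini-libri wins outright, and --full-libri is consulted only when it is off. A minimal standalone restatement (select_train_cuts is an illustrative name, not part of the diff; librispeech stands for the LibriSpeechAsrDataModule instance):

def select_train_cuts(params, librispeech):
    # Illustrative sketch of the precedence encoded in the diff:
    # --mini-libri takes priority; --full-libri matters only when it is False.
    if params.mini_libri:
        return librispeech.train_clean_5_cuts()     # Mini LibriSpeech, ~5 hours
    cuts = librispeech.train_clean_100_cuts()       # 100h clean subset
    if params.full_libri:
        cuts += librispeech.train_clean_360_cuts()  # +360h
        cuts += librispeech.train_other_500_cuts()  # +500h, 960h in total
    return cuts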
@@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     # if not params.print_diagnostics:
@@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="""Used only when --mini-libri is False.When enabled,
+            use 960h LibriSpeech. Otherwise, use 100h subset.""",
         )
+        group.add_argument(
+            "--mini-libri",
+            type=str2bool,
+            default=False,
+            help="True for mini librispeech",
+        )
+
         group.add_argument(
             "--manifest-dir",
             type=Path,
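Both flags are parsed with type=str2bool, which icefall supplies so that values like "True", "1", or "no" all work on the command line. A minimal sketch of such a helper, assuming the usual argparse pattern (not a copy of icefall's implementation):

import argparse

def str2bool(v):
    # Accept common spellings of true/false so that
    # "--mini-libri True" and "--mini-libri 0" both parse.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")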
@@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
         )
         return test_dl
 
+    @lru_cache()
+    def train_clean_5_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+        )
+
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
@@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
         )
 
+    @lru_cache()
+    def dev_clean_2_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
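A hedged usage sketch of the new getters. Because each is wrapped in @lru_cache(), a second call returns the same lazily loaded CutSet without re-reading the manifest. It assumes the datamodule's usual add_arguments() classmethod and that the mini-libri manifests already exist under the default --manifest-dir:

import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--mini-libri", "True"])

librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_5_cuts()  # reads librispeech_cuts_train-clean-5.jsonl.gz
valid_cuts = librispeech.dev_clean_2_cuts()    # reads librispeech_cuts_dev-clean-2.jsonl.gz
assert librispeech.train_clean_5_cuts() is train_cuts  # memoized by @lru_cache()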