From 5b50ffda54ac45ad481bbfd525f4c07880779eb8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 9 May 2023 15:10:06 +0800
Subject: [PATCH] support using mini librispeech in training (#1048)

* support mini librispeech in training

* update onnx export doc
---
 docs/source/model-export/export-onnx.rst      |  9 +++++++
 .../ASR/pruned_transducer_stateless7/train.py | 13 ++++++----
 .../train.py                                  | 18 +++++++++-----
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       | 24 ++++++++++++++++++-
 4 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst
index aa77204cb..fb952abb7 100644
--- a/docs/source/model-export/export-onnx.rst
+++ b/docs/source/model-export/export-onnx.rst
@@ -3,6 +3,15 @@ Export to ONNX
 
 In this section, we describe how to export models to `ONNX`_.
 
+.. hint::
+
+   Before you continue, please run:
+
+   .. code-block:: bash
+
+      pip install onnx
+
+
 In each recipe, there is a file called ``export-onnx.py``, which is used
 to export trained models to `ONNX`_.
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 1b179ceff..ed6dfc28f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -60,13 +60,13 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from multidataset import MultiDataset
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
+from multidataset import MultiDataset
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@@ -1056,7 +1056,9 @@ def run(rank, world_size, args):
         multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
         train_cuts = multidataset.train_cuts()
     else:
-        if params.full_libri:
+        if params.mini_libri:
+            train_cuts = librispeech.train_clean_5_cuts()
+        elif params.full_libri:
             train_cuts = librispeech.train_all_shuf_cuts()
         else:
             train_cuts = librispeech.train_clean_100_cuts()
@@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     if not params.use_multidataset and not params.print_diagnostics:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index b2f9ffc09..90428133d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -1049,10 +1049,13 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    if params.mini_libri:
+        train_cuts = librispeech.train_clean_5_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+        if params.full_libri:
+            train_cuts += librispeech.train_clean_360_cuts()
+            train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
+    if params.mini_libri:
+        valid_cuts = librispeech.dev_clean_2_cuts()
+    else:
+        valid_cuts = librispeech.dev_clean_cuts()
+        valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     # if not params.print_diagnostics:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c5787835d..c47964b07 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="""Used only when --mini-libri is False. When enabled,
+            use the full 960h LibriSpeech. Otherwise, use the 100h subset.""",
         )
+        group.add_argument(
+            "--mini-libri",
+            type=str2bool,
+            default=False,
+            help="True for mini LibriSpeech (train-clean-5 / dev-clean-2)",
+        )
+
         group.add_argument(
             "--manifest-dir",
             type=Path,
@@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
         )
         return test_dl
 
+    @lru_cache()
+    def train_clean_5_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+        )
+
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
@@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
         )
 
+    @lru_cache()
+    def dev_clean_2_cuts(self) -> CutSet:
+        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
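
Usage sketch (not part of the patch): once the mini LibriSpeech manifests are in place under ``--manifest-dir``, training on the small subsets can be launched as shown below. ``--world-size``, ``--num-epochs``, and ``--exp-dir`` are pre-existing flags of these recipes; the values shown are illustrative, not taken from the patch.

.. code-block:: bash

   cd egs/librispeech/ASR

   # --mini-libri is the flag added by this patch; the remaining flags
   # and their values are illustrative defaults, not part of the patch.
   ./pruned_transducer_stateless7/train.py \
     --mini-libri True \
     --world-size 1 \
     --num-epochs 30 \
     --exp-dir pruned_transducer_stateless7/exp

With ``--mini-libri True``, the patched code trains on ``train-clean-5`` and validates on ``dev-clean-2``, and ``--full-libri`` is ignored; the new cut accessors expect ``librispeech_cuts_train-clean-5.jsonl.gz`` and ``librispeech_cuts_dev-clean-2.jsonl.gz`` in the manifest directory.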