diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 1c9ec3e89..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -42,6 +42,11 @@ import argparse import logging from typing import List, Optional, Tuple +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index fb9adb44a..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -24,6 +24,11 @@ with the given torchscript model for the same input. import argparse import logging +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index c55268b14..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -21,6 +21,11 @@ This file is to test that models can be exported to onnx. """ import os +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch from conformer import ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 06c4b5204..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -21,7 +21,6 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface -from multi_quantization.prediction import JointCodebookLoss from scaling import ScaledLinear from icefall.utils import add_sos @@ -74,6 +73,14 @@ class Transducer(nn.Module): encoder_dim, vocab_size, initial_speed=0.5 ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + from icefall import is_module_available + + if not is_module_available("multi_quantization"): + raise ValueError("Please 'pip install multi_quantization' first.") + + from multi_quantization.prediction import JointCodebookLoss + if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( predictor_channels=encoder_dim, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 65895c920..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -28,18 +28,21 @@ from typing import List, Tuple import numpy as np import torch import torch.multiprocessing as mp -import multi_quantization as quantization +from icefall import is_module_available + +if not is_module_available("multi_quantization"): + raise ValueError("Please 'pip install multi_quantization' first.") + +import multi_quantization as quantization from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from icefall.utils import ( - AttributeDict, - setup_logger, -) from lhotse import CutSet, load_manifest from lhotse.cut import MonoCut from lhotse.features.io import NumpyHdf5Writer +from icefall.utils import AttributeDict, setup_logger + class CodebookIndexExtractor: """ diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index 91877ec46..c396c50ef 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -40,6 +40,11 @@ https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_s import argparse import logging +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py index 132517352..3770fbbb4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py @@ -49,6 +49,12 @@ from typing import List import k2 import kaldifeat import numpy as np + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch import torchaudio diff --git a/icefall/__init__.py b/icefall/__init__.py index 122226fdc..27ad74213 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -50,6 +50,7 @@ from .utils import ( get_executor, get_texts, is_jit_tracing, + is_module_available, l1_norm, l2_norm, linf_norm, diff --git a/icefall/ngram_lm.py b/icefall/ngram_lm.py index 23185e35a..63885a9d0 100644 --- a/icefall/ngram_lm.py +++ b/icefall/ngram_lm.py @@ -17,7 +17,7 @@ from collections import defaultdict from typing import List, Optional, Tuple -import kaldifst +from icefall.utils import is_module_available class NgramLm: @@ -36,6 +36,11 @@ class NgramLm: is_binary: True if the given file is a binary FST. """ + if not is_module_available("kaldifst"): + raise ValueError("Please 'pip install kaldifst' first.") + + import kaldifst + if is_binary: lm = kaldifst.StdVectorFst.read(fst_filename) else: @@ -85,6 +90,8 @@ class NgramLm: self, state: int, label: int ) -> Tuple[int, float]: """TODO: Add doc.""" + import kaldifst + arc_iter = kaldifst.ArcIterator(self.lm, state) num_arcs = self.lm.num_arcs(state) diff --git a/icefall/utils.py b/icefall/utils.py index ad079222e..6c115ed16 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -976,3 +976,17 @@ def display_and_save_batch( y = sp.encode(supervisions["text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") + + +# `is_module_available` is copied from +# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9 +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. + + Note: "borrowed" from torchaudio: + """ + import importlib + + return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/requirements.txt b/requirements.txt index 1c548c50a..5e32af853 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,4 @@ kaldialign sentencepiece>=0.1.96 tensorboard typeguard -multi_quantization -onnx -onnxruntime ---extra-index-url https://pypi.ngc.nvidia.com dill -kaldifst diff --git a/test/test_ngram_lm.py b/test/test_ngram_lm.py index bbf6bd51c..838c792d2 100755 --- a/test/test_ngram_lm.py +++ b/test/test_ngram_lm.py @@ -16,6 +16,12 @@ # limitations under the License. import graphviz + +from icefall import is_module_available + +if not is_module_available("kaldifst"): + raise ValueError("Please 'pip install kaldifst' first.") + import kaldifst from icefall import NgramLm, NgramLmStateCost