mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Remove ncnn as it is not fully tested yet.
This commit is contained in:
parent
0d5edbc3f2
commit
6e43a2b69d
@ -1,161 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ncnn models and uses them to decode waves.
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--model-dir /path/to/ncnn/model_dir
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
We assume there exist following files in the given `model_dir`:
|
||||
|
||||
- encoder_jit_trace.ncnn.param
|
||||
- encoder_jit_trace.ncnn.bin
|
||||
- decoder_jit_trace.ncnn.param
|
||||
- decoder_jit_trace.ncnn.bin
|
||||
- joiner_jit_trace.ncnn.param
|
||||
- joiner_jit_trace.ncnn.bin
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import ncnn
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the ncnn models directory. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model_dir = Path(args.model_dir)
|
||||
encoder_param = model_dir / "encoder_jit_trace.ncnn.param"
|
||||
encoder_bin = model_dir / "encoder_jit_trace.ncnn.bin"
|
||||
|
||||
decoder_param = model_dir / "decoder_jit_trace.ncnn.param"
|
||||
decoder_bin = model_dir / "decoder_jit_trace.ncnn.bin"
|
||||
|
||||
joiner_param = model_dir / "joiner_jit_trace.ncnn.param"
|
||||
joiner_bin = model_dir / "joiner_jit_trace.ncnn.bin"
|
||||
|
||||
assert encoder_param.is_file()
|
||||
assert encoder_bin.is_file()
|
||||
|
||||
assert decoder_param.is_file()
|
||||
assert decoder_bin.is_file()
|
||||
|
||||
assert joiner_param.is_file()
|
||||
assert joiner_bin.is_file()
|
||||
|
||||
encoder = ncnn.Net()
|
||||
decoder = ncnn.Net()
|
||||
joiner = ncnn.Net()
|
||||
|
||||
# encoder.load_param(str(encoder_param)) # not working yet
|
||||
# decoder.load_param(str(decoder_param))
|
||||
joiner.load_param(str(joiner_param))
|
||||
|
||||
encoder.clear()
|
||||
decoder.clear()
|
||||
joiner.clear()
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user