2022-07-30 21:17:31 +08:00

162 lines
4.2 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ncnn models and uses them to decode waves.
./pruned_transducer_stateless3/jit_pretrained.py \
--model-dir /path/to/ncnn/model_dir \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
We assume that the following files exist in the given `model_dir`:
- encoder_jit_trace.ncnn.param
- encoder_jit_trace.ncnn.bin
- decoder_jit_trace.ncnn.param
- decoder_jit_trace.ncnn.bin
- joiner_jit_trace.ncnn.param
- joiner_jit_trace.ncnn.bin
"""
import argparse
import logging
from pathlib import Path
from typing import List
import ncnn
import torch
import torchaudio
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for this script.

    Returns:
      An ``argparse.ArgumentParser`` accepting ``--model-dir`` (required),
      ``--bpe-model``, ``--sample-rate``, ``--context-size`` and one or
      more positional ``sound_files``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Path to the ncnn models directory. ",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        # The required rate is configurable via --sample-rate, so the help
        # text refers to that option instead of hard-coding a value.
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors, one per input file, in
      the same order as `filenames`. Only the first channel of each file
      is kept.
    Raises:
      ValueError:
        If the sample rate of a file differs from `expected_sample_rate`.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let mismatched audio through.
        if sample_rate != expected_sample_rate:
            raise ValueError(
                f"{f}: expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
        # We use only the first channel
        ans.append(wave[0])
    return ans
@torch.no_grad()
def main():
    """Parse the command line, check the model files and load the joiner.

    The six ncnn model files (param/bin for encoder, decoder and joiner)
    must exist in ``--model-dir``. Only the joiner's param file is
    currently loaded; loading the encoder and decoder is still disabled,
    so no decoding is actually performed yet.

    Raises:
      FileNotFoundError:
        If any of the expected model files is missing from ``--model-dir``.
    """
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    model_dir = Path(args.model_dir)

    encoder_param = model_dir / "encoder_jit_trace.ncnn.param"
    encoder_bin = model_dir / "encoder_jit_trace.ncnn.bin"

    decoder_param = model_dir / "decoder_jit_trace.ncnn.param"
    decoder_bin = model_dir / "decoder_jit_trace.ncnn.bin"

    joiner_param = model_dir / "joiner_jit_trace.ncnn.param"
    joiner_bin = model_dir / "joiner_jit_trace.ncnn.bin"

    # Explicit checks instead of bare asserts: they survive `python -O` and
    # report which file is missing.
    required_files = (
        encoder_param,
        encoder_bin,
        decoder_param,
        decoder_bin,
        joiner_param,
        joiner_bin,
    )
    for path in required_files:
        if not path.is_file():
            raise FileNotFoundError(f"Required model file not found: {path}")

    encoder = ncnn.Net()
    decoder = ncnn.Net()
    joiner = ncnn.Net()

    try:
        # encoder.load_param(str(encoder_param)) # not working yet
        # decoder.load_param(str(decoder_param))
        joiner.load_param(str(joiner_param))
    finally:
        # Always clear the nets, even if loading raised.
        encoder.clear()
        decoder.clear()
        joiner.clear()

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Configure the root logger first so that every message emitted by
    # main() carries a timestamp, level and source location.
    logging.basicConfig(
        level=logging.INFO,
        format=(
            "%(asctime)s %(levelname)s "
            "[%(filename)s:%(lineno)d] %(message)s"
        ),
    )
    main()