mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Convert conv-emformer to ncnn (#717)
* Export conv-emformer via torch.jit.trace()
This commit is contained in:
parent
be6e08f69a
commit
f13cf61b05
79
.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
vendored
Executable file
79
.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
vendored
Executable file
@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
cd exp
|
||||
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
soxi $repo/test_wavs/*.wav
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
log "Install ncnn and pnnx"
|
||||
|
||||
# We are using a modified ncnn here. Will try to merge it to the official repo
|
||||
# of ncnn
|
||||
git clone https://github.com/csukuangfj/ncnn
|
||||
pushd ncnn
|
||||
git submodule init
|
||||
git submodule update python/pybind11
|
||||
python3 setup.py bdist_wheel
|
||||
ls -lh dist/
|
||||
pip install dist/*.whl
|
||||
cd tools/pnnx
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
|
||||
make -j4 pnnx
|
||||
|
||||
./src/pnnx || echo "pass"
|
||||
|
||||
popd
|
||||
|
||||
log "Test exporting to pnnx format"
|
||||
|
||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
\
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32
|
||||
|
||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
||||
|
||||
./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
77
.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
vendored
Normal file
77
.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
vendored
Normal file
@ -0,0 +1,77 @@
|
||||
name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
jobs:
|
||||
run_librispeech_conv_emformer_transducer_stateless2_2022_12_05:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
mkdir -p egs/librispeech/ASR/data
|
||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||
ls -lh egs/librispeech/ASR/data/*
|
||||
|
||||
sudo apt-get -qq install git-lfs tree sox
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
|
@ -111,7 +111,7 @@ jobs:
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
|
||||
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
|
||||
|
||||
- name: Display decoding results for lstm_transducer_stateless2
|
||||
if: github.event_name == 'schedule'
|
||||
|
1798
egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
Normal file
1798
egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
Normal file
File diff suppressed because it is too large
Load Diff
335
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
Executable file
335
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
Executable file
@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
|
||||
cd ./conv_emformer_transducer_stateless2/exp
|
||||
pnnx encoder_jit_trace-pnnx.pt
|
||||
pnnx decoder_jit_trace-pnnx.pt
|
||||
pnnx joiner_jit_trace-pnnx.pt
|
||||
|
||||
You can find converted models at
|
||||
https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
|
||||
|
||||
See ./streaming-ncnn-decode.py
|
||||
and
|
||||
https://github.com/k2-fsa/sherpa-ncnn
|
||||
for usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
chunk_length = encoder_model.chunk_length # before subsampling
|
||||
right_context_length = encoder_model.right_context_length # before subsampling
|
||||
pad_length = right_context_length + 2 * 4 + 3
|
||||
s = f"chunk_length: {chunk_length}, "
|
||||
s += f"right_context_length: {right_context_length}\n"
|
||||
logging.info(s)
|
||||
|
||||
T = chunk_length + pad_length
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
states = encoder_model.init_states()
|
||||
states = encoder_model.init_states()
|
||||
|
||||
traced_model = torch.jit.trace(encoder_model, (x, states))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
292
egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
Executable file
292
egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
Executable file
@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models exported by `torch.jit.trace()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./conv_emformer_transducer_stateless2/jit_pretrained.py \
|
||||
--encoder-model-filename ./conv_emformer_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt \
|
||||
--decoder-model-filename ./conv_emformer_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt \
|
||||
--joiner-model-filename ./conv_emformer_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
assert encoder_out.ndim == 2
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
|
||||
# decoder_input.shape (1,, 1 context_size)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
|
||||
else:
|
||||
assert decoder_out.ndim == 2
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for i in range(T):
|
||||
cur_encoder_out = encoder_out[i : i + 1]
|
||||
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor(args.sample_rate)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
logging.info("Decoding started")
|
||||
chunk_length = encoder.chunk_length
|
||||
right_context_length = encoder.right_context_length
|
||||
|
||||
# Assume the subsampling factor is 4
|
||||
pad_length = right_context_length + 2 * 4 + 3
|
||||
T = chunk_length + pad_length
|
||||
|
||||
logging.info(f"chunk_length: {chunk_length}")
|
||||
logging.info(f"right_context_length: {right_context_length}")
|
||||
|
||||
states = encoder.init_states(device)
|
||||
logging.info(f"num layers: {len(states)//4}")
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
chunk = int(0.25 * args.sample_rate) # 0.2 second
|
||||
num_processed_frames = 0
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
logging.info(f"{start}/{wave_samples.numel()}")
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=args.sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= T:
|
||||
frames = []
|
||||
for i in range(T):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += chunk_length
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
# TODO(fangjun): remove x_lens
|
||||
x_lens = torch.tensor([T])
|
||||
encoder_out, _, states = encoder(frames, x_lens, states)
|
||||
|
||||
hyp, decoder_out = greedy_search(
|
||||
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
|
||||
context_size = 2
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(sp.decode(hyp[context_size:]))
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
Symbolic link
1
egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless3/scaling_converter.py
|
387
egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
Executable file
387
egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
Executable file
@ -0,0 +1,387 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
|
||||
--tokens ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/tokens.txt \
|
||||
--encoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
|
||||
--encoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
|
||||
--decoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
|
||||
--decoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
|
||||
--joiner-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
|
||||
--joiner-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
|
||||
./sherpa-ncnn-conv-emformer-transducer-2022-12-04/test_wavs/1089-134686-0001.wav
|
||||
|
||||
You can find pretrained models at
|
||||
https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import ncnn
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-param-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-param-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-param-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-bin-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename",
|
||||
type=str,
|
||||
help="Path to foo.wav",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, args):
|
||||
self.init_encoder(args)
|
||||
self.init_decoder(args)
|
||||
self.init_joiner(args)
|
||||
|
||||
self.num_layers = 12
|
||||
self.memory_size = 32
|
||||
self.d_model = 512
|
||||
self.cnn_module_kernel = 31
|
||||
|
||||
self.left_context_length = 32 // 4 # after subsampling
|
||||
self.chunk_length = 32 # before subsampling
|
||||
right_context_length = 8 # before subsampling
|
||||
pad_length = right_context_length + 2 * 4 + 3
|
||||
self.T = self.chunk_length + pad_length
|
||||
print("T", self.T, self.chunk_length)
|
||||
|
||||
def get_init_states(self) -> List[torch.Tensor]:
|
||||
states = []
|
||||
|
||||
for i in range(self.num_layers):
|
||||
s0 = torch.zeros(self.memory_size, self.d_model)
|
||||
s1 = torch.zeros(self.left_context_length, self.d_model)
|
||||
s2 = torch.zeros(self.left_context_length, self.d_model)
|
||||
s3 = torch.zeros(self.d_model, self.cnn_module_kernel - 1)
|
||||
states.extend([s0, s1, s2, s3])
|
||||
|
||||
return states
|
||||
|
||||
def init_encoder(self, args):
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
encoder_net.load_param(encoder_param)
|
||||
encoder_net.load_model(encoder_model)
|
||||
|
||||
self.encoder_net = encoder_net
|
||||
|
||||
def init_decoder(self, args):
|
||||
decoder_param = args.decoder_param_filename
|
||||
decoder_model = args.decoder_bin_filename
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
|
||||
self.decoder_net = decoder_net
|
||||
|
||||
def init_joiner(self, args):
|
||||
joiner_param = args.joiner_param_filename
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
self.joiner_net = joiner_net
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
states: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (T, C)
|
||||
states:
|
||||
A list of tensors. len(states) == self.num_layers * 4
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, a tensor of shape (T, encoder_dim).
|
||||
- next_states, a list of tensors containing the next states
|
||||
"""
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
|
||||
# layer0 in2-in5
|
||||
# layer1 in6-in9
|
||||
for i in range(self.num_layers):
|
||||
offset = 1 + i * 4
|
||||
name = f"in{offset}"
|
||||
# (32, 1, 512) -> (32, 512)
|
||||
ex.input(name, ncnn.Mat(states[i * 4 + 0].numpy()).clone())
|
||||
|
||||
name = f"in{offset+1}"
|
||||
# (8, 1, 512) -> (8, 512)
|
||||
ex.input(name, ncnn.Mat(states[i * 4 + 1].numpy()).clone())
|
||||
|
||||
name = f"in{offset+2}"
|
||||
# (8, 1, 512) -> (8, 512)
|
||||
ex.input(name, ncnn.Mat(states[i * 4 + 2].numpy()).clone())
|
||||
|
||||
name = f"in{offset+3}"
|
||||
# (1, 512, 2) -> (512, 2)
|
||||
ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
|
||||
|
||||
import pdb
|
||||
|
||||
# pdb.set_trace()
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
# assert ret == 0, ret
|
||||
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
|
||||
out_states: List[torch.Tensor] = []
|
||||
for i in range(4 * self.num_layers):
|
||||
name = f"out{i+1}"
|
||||
ret, ncnn_out_state = ex.extract(name)
|
||||
assert ret == 0, ret
|
||||
ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
|
||||
out_states.append(ncnn_out_state)
|
||||
|
||||
return encoder_out, out_states
|
||||
|
||||
def run_decoder(self, decoder_input):
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return decoder_out
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||
return joiner_out
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Model,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
else:
|
||||
assert decoder_out.ndim == 1
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t]
|
||||
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = Model(args)
|
||||
|
||||
sound_file = args.sound_filename
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
states = model.get_init_states()
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.T
|
||||
offset = model.chunk_length
|
||||
|
||||
chunk = int(1 * sample_rate) # 0.2 second
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
encoder_out, states = model.run_encoder(frames, states)
|
||||
hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
context_size = 2
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(sound_file)
|
||||
logging.info(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
1128
egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
Executable file
1128
egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
Executable file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user