mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Support exporting LSTM with projection to ONNX (#621)
* Support exporting LSTM with projection to ONNX * Add missing files * small fixes
This commit is contained in:
parent
d1f16a04bd
commit
d69bb826ed
@ -42,7 +42,7 @@ for sym in 1 2 3; do
|
||||
--lang-dir $repo/data/lang_char \
|
||||
$repo/test_wavs/BAC009S0764W0121.wav \
|
||||
$repo/test_wavs/BAC009S0764W0122.wav \
|
||||
$rep/test_wavs/BAC009S0764W0123.wav
|
||||
$repo/test_wavs/BAC009S0764W0123.wav
|
||||
done
|
||||
|
||||
for method in modified_beam_search beam_search fast_beam_search; do
|
||||
@ -55,7 +55,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
||||
--lang-dir $repo/data/lang_char \
|
||||
$repo/test_wavs/BAC009S0764W0121.wav \
|
||||
$repo/test_wavs/BAC009S0764W0122.wav \
|
||||
$rep/test_wavs/BAC009S0764W0123.wav
|
||||
$repo/test_wavs/BAC009S0764W0123.wav
|
||||
done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
|
@ -105,6 +105,47 @@ log "Decode with models exported by torch.jit.trace()"
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Test exporting to ONNX"
|
||||
|
||||
./lstm_transducer_stateless2/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--onnx 1
|
||||
|
||||
log "Decode with ONNX models "
|
||||
|
||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
||||
$repo/test_wavs/1221-135766-0001.wav
|
||||
|
||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
@ -133,7 +174,7 @@ done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
||||
mkdir -p lstm_transducer_stateless2/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||
|
@ -13,10 +13,14 @@ cd egs/librispeech/ASR
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
git clone $repo_url
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt"
|
||||
popd
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
soxi $repo/test_wavs/*.wav
|
||||
|
@ -1,4 +1,4 @@
|
||||
name: run-librispeech-lstm-transducer-2022-09-03
|
||||
name: run-librispeech-lstm-transducer2-2022-09-03
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -17,8 +17,8 @@ on:
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
@ -110,7 +110,7 @@ jobs:
|
||||
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
|
||||
|
||||
- name: Display decoding results for lstm_transducer_stateless2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
|
||||
if: github.event_name == 'schedule'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR
|
||||
@ -130,7 +130,7 @@ jobs:
|
||||
|
||||
- name: Upload decoding results for lstm_transducer_stateless2
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
|
||||
if: github.event_name == 'schedule'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
|
||||
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
|
||||
|
1
egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../lstm_transducer_stateless2/lstmp.py
|
@ -74,6 +74,29 @@ with the following commands:
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||
# You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
|
||||
|
||||
(3) Export to ONNX format
|
||||
|
||||
./lstm_transducer_stateless2/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Please see ./streaming-onnx-decode.py for usage of the generated files
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -181,6 +204,23 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit and --pnnx are ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -266,6 +306,215 @@ def export_joiner_model_jit_trace(
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has 3 inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
- states: a tuple containing:
|
||||
- h0: a tensor of shape (num_layers, N, proj_size)
|
||||
- c0: a tensor of shape (num_layers, N, hidden_size)
|
||||
|
||||
and it has 3 outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
- states: a tuple containing:
|
||||
- next_h0: a tensor of shape (num_layers, N, proj_size)
|
||||
- next_c0: a tensor of shape (num_layers, N, hidden_size)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
N = 1
|
||||
x = torch.zeros(N, 9, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([9], dtype=torch.int64)
|
||||
h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model)
|
||||
c = torch.rand(
|
||||
encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size
|
||||
)
|
||||
|
||||
warmup = 1.0
|
||||
torch.onnx.export(
|
||||
encoder_model, # use torch.jit.trace() internally
|
||||
(x, x_lens, (h, c), warmup),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens", "h", "c", "warmup"],
|
||||
output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"h": {1: "N"},
|
||||
"c": {1: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
"next_h": {1: "N"},
|
||||
"next_c": {1: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
|
||||
The exported encoder_proj model has one input:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
The exported decoder_proj model has one input:
|
||||
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_encoder_proj.onnx"
|
||||
)
|
||||
|
||||
decoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_decoder_proj.onnx"
|
||||
)
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
project_input = False
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"projected_encoder_out",
|
||||
"projected_decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"projected_encoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
@ -387,7 +636,33 @@ def main():
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.pnnx:
|
||||
if params.onnx:
|
||||
logging.info("Export model to ONNX format")
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||
|
||||
opset_version = 11
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
elif params.pnnx:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||
|
102
egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py
Normal file
102
egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LSTMP(nn.Module):
|
||||
"""LSTM with projection.
|
||||
|
||||
PyTorch does not support exporting LSTM with projection to ONNX.
|
||||
This class reimplements LSTM with projection using basic matrix-matrix
|
||||
and matrix-vector operations. It is not intended for training.
|
||||
"""
|
||||
|
||||
def __init__(self, lstm: nn.LSTM):
|
||||
"""
|
||||
Args:
|
||||
lstm:
|
||||
LSTM with proj_size. We support only uni-directional,
|
||||
1-layer LSTM with projection at present.
|
||||
"""
|
||||
super().__init__()
|
||||
assert lstm.bidirectional is False, lstm.bidirectional
|
||||
assert lstm.num_layers == 1, lstm.num_layers
|
||||
assert 0 < lstm.proj_size < lstm.hidden_size, (
|
||||
lstm.proj_size,
|
||||
lstm.hidden_size,
|
||||
)
|
||||
|
||||
assert lstm.batch_first is False, lstm.batch_first
|
||||
|
||||
state_dict = lstm.state_dict()
|
||||
|
||||
w_ih = state_dict["weight_ih_l0"]
|
||||
w_hh = state_dict["weight_hh_l0"]
|
||||
|
||||
b_ih = state_dict["bias_ih_l0"]
|
||||
b_hh = state_dict["bias_hh_l0"]
|
||||
|
||||
w_hr = state_dict["weight_hr_l0"]
|
||||
self.input_size = lstm.input_size
|
||||
self.proj_size = lstm.proj_size
|
||||
self.hidden_size = lstm.hidden_size
|
||||
|
||||
self.w_ih = w_ih
|
||||
self.w_hh = w_hh
|
||||
self.b = b_ih + b_hh
|
||||
self.w_hr = w_hr
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
input:
|
||||
A tensor of shape [T, N, hidden_size]
|
||||
hx:
|
||||
A tuple containing:
|
||||
- h0: a tensor of shape (1, N, proj_size)
|
||||
- c0: a tensor of shape (1, N, hidden_size)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- output: a tensor of shape (T, N, proj_size).
|
||||
- A tuple containing:
|
||||
- h: a tensor of shape (1, N, proj_size)
|
||||
- c: a tensor of shape (1, N, hidden_size)
|
||||
|
||||
"""
|
||||
x_list = input.unbind(dim=0) # We use batch_first=False
|
||||
|
||||
if hx is not None:
|
||||
h0, c0 = hx
|
||||
else:
|
||||
h0 = torch.zeros(1, input.size(1), self.proj_size)
|
||||
c0 = torch.zeros(1, input.size(1), self.hidden_size)
|
||||
h0 = h0.squeeze(0)
|
||||
c0 = c0.squeeze(0)
|
||||
y_list = []
|
||||
for x in x_list:
|
||||
gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh)
|
||||
i, f, g, o = gates.chunk(4, dim=1)
|
||||
|
||||
i = i.sigmoid()
|
||||
f = f.sigmoid()
|
||||
g = g.tanh()
|
||||
o = o.sigmoid()
|
||||
|
||||
c = f * c0 + i * g
|
||||
h = o * c.tanh()
|
||||
|
||||
h = F.linear(h, self.w_hr)
|
||||
y_list.append(h)
|
||||
|
||||
c0 = c
|
||||
h0 = h
|
||||
|
||||
y = torch.stack(y_list, dim=0)
|
||||
|
||||
return y, (h0.unsqueeze(0), c0.unsqueeze(0))
|
@ -233,13 +233,12 @@ def greedy_search(
|
||||
hyp, dtype=torch.int32
|
||||
) # (1, context_size)
|
||||
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||
|
||||
else:
|
||||
assert decoder_out.ndim == 1
|
||||
assert hyp is not None, hyp
|
||||
|
||||
joiner_out = model.run_joiner(encoder_out, decoder_out)
|
||||
y = joiner_out.argmax(dim=0).tolist()
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
|
478
egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
Executable file
478
egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
Executable file
@ -0,0 +1,478 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./lstm_transducer_stateless2/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./lstm_transducer_stateless2/onnx-streaming-decode.py \
|
||||
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder.onnx \
|
||||
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder.onnx \
|
||||
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_decoder_proj.onnx \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model-filename",
|
||||
type=str,
|
||||
help="Path to bpe.model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-encoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner encoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-decoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner decoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename",
|
||||
type=str,
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, args):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 5
|
||||
session_opts.intra_op_num_threads = 5
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(args)
|
||||
self.init_decoder(args)
|
||||
self.init_joiner(args)
|
||||
self.init_joiner_encoder_proj(args)
|
||||
self.init_joiner_decoder_proj(args)
|
||||
|
||||
def init_encoder(self, args):
|
||||
self.encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_decoder(self, args):
|
||||
self.decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_joiner(self, args):
|
||||
self.joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_joiner_encoder_proj(self, args):
|
||||
self.joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_joiner_decoder_proj(self, args):
|
||||
self.joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def run_encoder(
|
||||
self, x, h0, c0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (N, T, C)
|
||||
h0:
|
||||
A tensor of shape (num_layers, N, proj_size)
|
||||
c0:
|
||||
A tensor of shape (num_layers, N, hidden_size)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out: A tensor of shape (N, T', C')
|
||||
- next_h0: A tensor of shape (num_layers, N, proj_size)
|
||||
- next_c0: A tensor of shape (num_layers, N, hidden_size)
|
||||
"""
|
||||
encoder_input_nodes = self.encoder.get_inputs()
|
||||
encoder_out_nodes = self.encoder.get_outputs()
|
||||
x_lens = torch.tensor([x.size(1)], dtype=torch.int64)
|
||||
|
||||
encoder_out, encoder_out_lens, next_h0, next_c0 = self.encoder.run(
|
||||
[
|
||||
encoder_out_nodes[0].name,
|
||||
encoder_out_nodes[1].name,
|
||||
encoder_out_nodes[2].name,
|
||||
encoder_out_nodes[3].name,
|
||||
],
|
||||
{
|
||||
encoder_input_nodes[0].name: x.numpy(),
|
||||
encoder_input_nodes[1].name: x_lens.numpy(),
|
||||
encoder_input_nodes[2].name: h0.numpy(),
|
||||
encoder_input_nodes[3].name: c0.numpy(),
|
||||
},
|
||||
)
|
||||
return (
|
||||
torch.from_numpy(encoder_out),
|
||||
torch.from_numpy(next_h0),
|
||||
torch.from_numpy(next_c0),
|
||||
)
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A tensor of shape (N, context_size). Its dtype is torch.int64.
|
||||
Returns:
|
||||
Return a tensor of shape (N, 1, decoder_out_dim).
|
||||
"""
|
||||
decoder_input_nodes = self.decoder.get_inputs()
|
||||
decoder_output_nodes = self.decoder.get_outputs()
|
||||
|
||||
decoder_out = self.decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return self.run_joiner_decoder_proj(
|
||||
torch.from_numpy(decoder_out).squeeze(1)
|
||||
)
|
||||
|
||||
def run_joiner(
|
||||
self,
|
||||
projected_encoder_out: torch.Tensor,
|
||||
projected_decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
projected_encoder_out:
|
||||
A tensor of shape (N, joiner_dim)
|
||||
projected_decoder_out:
|
||||
A tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_input_nodes = self.joiner.get_inputs()
|
||||
joiner_output_nodes = self.joiner.get_outputs()
|
||||
|
||||
logits = self.joiner.run(
|
||||
[joiner_output_nodes[0].name],
|
||||
{
|
||||
joiner_input_nodes[0].name: projected_encoder_out.numpy(),
|
||||
joiner_input_nodes[1].name: projected_decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def run_joiner_encoder_proj(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A tensor of shape (N, encoder_out_dim)
|
||||
Returns:
|
||||
A tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
|
||||
projected_encoder_out = self.joiner_encoder_proj.run(
|
||||
[self.joiner_encoder_proj.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner_encoder_proj.get_inputs()[
|
||||
0
|
||||
].name: encoder_out.numpy()
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(projected_encoder_out)
|
||||
|
||||
def run_joiner_decoder_proj(
|
||||
self,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_out:
|
||||
A tensor of shape (N, decoder_out_dim)
|
||||
Returns:
|
||||
A tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
|
||||
projected_decoder_out = self.joiner_decoder_proj.run(
|
||||
[self.joiner_decoder_proj.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner_decoder_proj.get_inputs()[
|
||||
0
|
||||
].name: decoder_out.numpy()
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(projected_decoder_out)
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Model,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
assert encoder_out.ndim == 2
|
||||
assert encoder_out.shape[0] == 1, "TODO: support batch_size > 1"
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(
|
||||
[hyp], dtype=torch.int64
|
||||
) # (1, context_size)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert decoder_out.shape[0] == 1
|
||||
assert hyp is not None, hyp
|
||||
|
||||
projected_encoder_out = model.run_joiner_encoder_proj(encoder_out)
|
||||
|
||||
joiner_out = model.run_joiner(projected_encoder_out, decoder_out)
|
||||
y = joiner_out.squeeze(0).argmax(dim=0).item()
|
||||
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = Model(args)
|
||||
|
||||
sound_file = args.sound_filename
|
||||
sample_rate = 16000
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model_filename)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
num_encoder_layers = 12
|
||||
batch_size = 1
|
||||
d_model = 512
|
||||
rnn_hidden_size = 1024
|
||||
|
||||
h0 = torch.zeros(num_encoder_layers, batch_size, d_model)
|
||||
c0 = torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size)
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = 9
|
||||
offset = 4
|
||||
|
||||
chunk = 3200 # 0.2 second
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
encoder_out, h0, c0 = model.run_encoder(frames, h0, c0)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate, waveform=torch.zeros(5000, dtype=torch.float)
|
||||
)
|
||||
|
||||
online_fbank.input_finished()
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
encoder_out, h0, c0 = model.run_encoder(frames, h0, c0)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
|
||||
context_size = 2
|
||||
|
||||
logging.info(sound_file)
|
||||
logging.info(sp.decode(hyp[context_size:]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
70
egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py
Executable file
70
egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py
Executable file
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lstmp import LSTMP
|
||||
|
||||
|
||||
def test():
|
||||
input_size = torch.randint(low=10, high=1024, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=10, high=1024, size=(1,)).item()
|
||||
proj_size = hidden_size - 1
|
||||
lstm = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
bias=True,
|
||||
proj_size=proj_size,
|
||||
)
|
||||
lstmp = LSTMP(lstm)
|
||||
|
||||
N = torch.randint(low=1, high=10, size=(1,)).item()
|
||||
T = torch.randint(low=1, high=20, size=(1,)).item()
|
||||
x = torch.rand(T, N, input_size)
|
||||
h0 = torch.rand(1, N, proj_size)
|
||||
c0 = torch.rand(1, N, hidden_size)
|
||||
|
||||
y1, (h1, c1) = lstm(x, (h0, c0))
|
||||
y2, (h2, c2) = lstmp(x, (h0, c0))
|
||||
|
||||
assert torch.allclose(y1, y2, atol=1e-5), (y1 - y2).abs().max()
|
||||
assert torch.allclose(h1, h2, atol=1e-5), (h1 - h2).abs().max()
|
||||
assert torch.allclose(c1, c2, atol=1e-5), (c1 - c2).abs().max()
|
||||
|
||||
# lstm_script = torch.jit.script(lstm) # pytorch does not support it
|
||||
lstm_script = lstm
|
||||
lstmp_script = torch.jit.script(lstmp)
|
||||
|
||||
y3, (h3, c3) = lstm_script(x, (h0, c0))
|
||||
y4, (h4, c4) = lstmp_script(x, (h0, c0))
|
||||
|
||||
assert torch.allclose(y3, y4, atol=1e-5), (y3 - y4).abs().max()
|
||||
assert torch.allclose(h3, h4, atol=1e-5), (h3 - h4).abs().max()
|
||||
assert torch.allclose(c3, c4, atol=1e-5), (c3 - c4).abs().max()
|
||||
|
||||
assert torch.allclose(y3, y1, atol=1e-5), (y3 - y1).abs().max()
|
||||
assert torch.allclose(h3, h1, atol=1e-5), (h3 - h1).abs().max()
|
||||
assert torch.allclose(c3, c1, atol=1e-5), (c3 - c1).abs().max()
|
||||
|
||||
lstm_trace = torch.jit.trace(lstm, (x, (h0, c0)))
|
||||
lstmp_trace = torch.jit.trace(lstmp, (x, (h0, c0)))
|
||||
|
||||
y5, (h5, c5) = lstm_trace(x, (h0, c0))
|
||||
y6, (h6, c6) = lstmp_trace(x, (h0, c0))
|
||||
|
||||
assert torch.allclose(y5, y6, atol=1e-5), (y5 - y6).abs().max()
|
||||
assert torch.allclose(h5, h6, atol=1e-5), (h5 - h6).abs().max()
|
||||
assert torch.allclose(c5, c6, atol=1e-5), (c5 - c6).abs().max()
|
||||
|
||||
assert torch.allclose(y5, y1, atol=1e-5), (y5 - y1).abs().max()
|
||||
assert torch.allclose(h5, h1, atol=1e-5), (h5 - h1).abs().max()
|
||||
assert torch.allclose(c5, c1, atol=1e-5), (c5 - c1).abs().max()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../lstm_transducer_stateless2/lstmp.py
|
1
egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py
Symbolic link
1
egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../lstm_transducer_stateless2/lstmp.py
|
@ -29,6 +29,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lstmp import LSTMP
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
@ -259,7 +260,11 @@ def get_submodule(model, target):
|
||||
return mod
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
|
||||
def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_onnx: bool = False,
|
||||
):
|
||||
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
|
||||
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
|
||||
and `nn.Conv2d`.
|
||||
@ -270,6 +275,9 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_onnx:
|
||||
If True, we are going to export the model to ONNX. In this case,
|
||||
we will convert nn.LSTM with proj_size to LSTMP.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
@ -294,6 +302,12 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
|
||||
elif isinstance(m, BasicNorm):
|
||||
d[name] = convert_basic_norm(m)
|
||||
elif isinstance(m, ScaledLSTM):
|
||||
if is_onnx:
|
||||
d[name] = LSTMP(scaled_lstm_to_lstm(m))
|
||||
# See
|
||||
# https://github.com/pytorch/pytorch/issues/47887
|
||||
# d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m)))
|
||||
else:
|
||||
d[name] = scaled_lstm_to_lstm(m)
|
||||
elif isinstance(m, ActivationBalancer):
|
||||
d[name] = nn.Identity()
|
||||
|
1
egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
1
egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
Loading…
x
Reference in New Issue
Block a user