Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
support jit trace
This commit is contained in: parent 03b056ca37, commit 45c7894111
egs/librispeech/ASR/lstm_transducer_stateless/export.py (581 lines, new executable file)
@@ -0,0 +1,581 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:

(1) Export to torchscript model using torch.jit.trace()

./lstm_transducer_stateless/export.py \
  --exp-dir ./lstm_transducer_stateless/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --jit-trace 1

It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.

(2) Export to ONNX format

./lstm_transducer_stateless/export.py \
  --exp-dir ./lstm_transducer_stateless/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --onnx 1

It will generate the following files in the given `exp_dir`
(`all_in_one.onnx` merges the other three).
Check `onnx_check.py` for how to use them.

  - encoder.onnx
  - decoder.onnx
  - joiner.onnx
  - all_in_one.onnx

(3) Export `model.state_dict()`

./lstm_transducer_stateless/export.py \
  --exp-dir ./lstm_transducer_stateless/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10

It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.

To use the generated file with `lstm_transducer_stateless/decode.py`,
you can do:

  cd /path/to/exp_dir
  ln -s pretrained.pt epoch-9999.pt

  cd /path/to/egs/librispeech/ASR
  ./lstm_transducer_stateless/decode.py \
    --exp-dir ./lstm_transducer_stateless/exp \
    --epoch 9999 \
    --avg 1 \
    --max-duration 600 \
    --decoding-method greedy_search \
    --bpe-model data/lang_bpe_500/bpe.model

Check ./pretrained.py for its usage.

Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at

https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13

with the following commands:

  sudo apt-get install git-lfs
  git lfs install
  git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
  # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""

import argparse
import logging
from pathlib import Path

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit-trace",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.trace.
        It will generate 3 files:
          - encoder_jit_trace.pt
          - decoder_jit_trace.pt
          - joiner_jit_trace.pt

        Check ./jit_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--onnx",
        type=str2bool,
        default=False,
        help="""If True, --jit-trace is ignored and it exports the model
        to onnx format. Four files will be generated:

          - encoder.onnx
          - decoder.onnx
          - joiner.onnx
          - all_in_one.onnx

        Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser


def export_encoder_model_jit_trace(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Export the given encoder model with torch.jit.trace()

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)
    states = encoder_model.get_init_states()
    states = (states[0].unsqueeze(1), states[1].unsqueeze(1))

    traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
    traced_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_jit_trace(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.trace()

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = torch.tensor([False])

    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
    traced_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_trace(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
    traced_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")


def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:

      - x, a tensor of shape (N, T, C); dtype is torch.float32
      - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

      - encoder_out, a tensor of shape (N, T, C)
      - encoder_out_lens, a tensor of shape (N,)

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)
    states = encoder_model.get_init_states()
    hidden_states = states[0].unsqueeze(1)
    cell_states = states[1].unsqueeze(1)

    # encoder_model = torch.jit.script(encoder_model)
    # It throws the following error for the above statement
    #
    # RuntimeError: Exporting the operator __is_ to ONNX opset version
    # 11 is not supported. Please feel free to request support or
    # submit a pull request on PyTorch GitHub.
    #
    # I cannot find which statement causes the above error.
    # torch.onnx.export() will use torch.jit.trace() internally, which
    # works well for the current reworked model
    warmup = 1.0
    torch.onnx.export(
        encoder_model,
        (x, x_lens, (hidden_states, cell_states), warmup),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens", "hidden_states", "cell_states", "warmup"],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "new_hidden_states",
            "new_cell_states",
        ],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "hidden_states": {1: "N"},
            "cell_states": {1: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
            "new_hidden_states": {1: "N"},
            "new_cell_states": {1: "N"},
        },
    )
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

      - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

      - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = False  # Always False, so we can use torch.jit.trace() here
    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
    # in this case
    torch.onnx.export(
        decoder_model,
        (y, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported model has two inputs:

      - encoder_out: a tensor of shape (N, encoder_out_dim)
      - decoder_out: a tensor of shape (N, decoder_out_dim)

    and has one output:

      - joiner_out: a tensor of shape (N, vocab_size)

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    project_input = True
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (encoder_out, decoder_out, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out", "project_input"],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")


def export_all_in_one_onnx(
    encoder_filename: str,
    decoder_filename: str,
    joiner_filename: str,
    all_in_one_filename: str,
):
    encoder_onnx = onnx.load(encoder_filename)
    decoder_onnx = onnx.load(decoder_filename)
    joiner_onnx = onnx.load(joiner_filename)

    encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
    decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
    joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")

    combined_model = onnx.compose.merge_models(
        encoder_onnx, decoder_onnx, io_map={}
    )
    combined_model = onnx.compose.merge_models(
        combined_model, joiner_onnx, io_map={}
    )
    onnx.save(combined_model, all_in_one_filename)
    logging.info(f"Saved to {all_in_one_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )

    model.to("cpu")
    model.eval()
    convert_scaled_to_non_scaled(model, inplace=True)

    if params.onnx is True:
        opset_version = 11
        logging.info("Exporting to onnx format")
        encoder_filename = params.exp_dir / "encoder.onnx"
        export_encoder_model_onnx(
            model.encoder,
            encoder_filename,
            opset_version=opset_version,
        )

        decoder_filename = params.exp_dir / "decoder.onnx"
        export_decoder_model_onnx(
            model.decoder,
            decoder_filename,
            opset_version=opset_version,
        )

        joiner_filename = params.exp_dir / "joiner.onnx"
        export_joiner_model_onnx(
            model.joiner,
            joiner_filename,
            opset_version=opset_version,
        )

        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
        export_all_in_one_onnx(
            encoder_filename,
            decoder_filename,
            joiner_filename,
            all_in_one_filename,
        )
    elif params.jit_trace is True:
        logging.info("Using torch.jit.trace()")
        encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
        export_encoder_model_jit_trace(model.encoder, encoder_filename)

        decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
        export_decoder_model_jit_trace(model.decoder, decoder_filename)

        joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
        export_joiner_model_jit_trace(model.joiner, joiner_filename)
    else:
        logging.info("Not using torchscript")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
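The exported `encoder.onnx` can be sanity-checked with onnxruntime, using the input and output names assigned by export_encoder_model_onnx() above. A minimal sketch, assuming onnxruntime is installed; the state sizes below are illustrative and must match the trained model, and the constant `warmup` input is normally folded away during tracing, leaving four graph inputs:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])

x = np.zeros((1, 100, 80), dtype=np.float32)   # (N, T, C) fbank features
x_lens = np.array([100], dtype=np.int64)       # (N,)
h = np.zeros((12, 1, 512), dtype=np.float32)   # (num_layers, N, d_model), illustrative
c = np.zeros((12, 1, 1024), dtype=np.float32)  # (num_layers, N, rnn_hidden_size), illustrative

encoder_out, encoder_out_lens, new_h, new_c = session.run(
    ["encoder_out", "encoder_out_lens", "new_hidden_states", "new_cell_states"],
    {"x": x, "x_lens": x_lens, "hidden_states": h, "cell_states": c},
)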
egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py (319 lines, new executable file)
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:

./lstm_transducer_stateless/export.py \
  --exp-dir ./lstm_transducer_stateless/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --jit-trace 1

Usage of this script:

./lstm_transducer_stateless/jit_pretrained.py \
  --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \
  --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \
  --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  /path/to/foo.wav \
  /path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder torchscript model.",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder torchscript model.",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner torchscript model.",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans


def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      decoder:
        The decoder model.
      joiner:
        The joiner model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder(
        decoder_input,
        need_pad=torch.tensor([False]),
    ).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = decoder(
                decoder_input,
                need_pad=torch.tensor([False]),
            )
            decoder_out = decoder_out.squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    encoder = torch.jit.load(args.encoder_model_filename)
    decoder = torch.jit.load(args.decoder_model_filename)
    joiner = torch.jit.load(args.joiner_model_filename)

    encoder.eval()
    decoder.eval()
    joiner.eval()

    encoder.to(device)
    decoder.to(device)
    joiner.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens, _ = encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )
    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = sp.decode(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
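Because export_encoder_model_jit_trace() traces the encoder together with its LSTM states, the saved encoder_jit_trace.pt can also be driven chunk by chunk. A minimal streaming-style sketch, assuming the traced files produced by export.py; the layer and state sizes and the chunk length are illustrative and must match the trained model:

import torch

encoder = torch.jit.load("encoder_jit_trace.pt")
encoder.eval()

# States have shapes (num_layers, N, d_model) and (num_layers, N, rnn_hidden_size).
num_layers, d_model, rnn_hidden_size = 12, 512, 1024  # illustrative
h = torch.zeros(num_layers, 1, d_model)
c = torch.zeros(num_layers, 1, rnn_hidden_size)

feature_chunks = [torch.zeros(1, 20, 80) for _ in range(3)]  # stand-in fbank chunks
for chunk in feature_chunks:
    chunk_lens = torch.tensor([chunk.size(1)], dtype=torch.int64)
    encoder_out, encoder_out_lens, (h, c) = encoder(chunk, chunk_lens, (h, c))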
@@ -51,7 +51,8 @@ class RNN(EncoderInterface):
         Dropout value for model-level warmup (default=0.075).
       aux_layer_period (int):
         Period of auxiliary layers used for random combining during training.
-        If not larger than 0, will not use the random combiner.
+        If set to 0, will not use the random combiner (Default).
+        You can set a positive integer to use the random combiner, e.g., 3.
     """

     def __init__(
@@ -64,7 +65,7 @@ class RNN(EncoderInterface):
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        aux_layer_period: int = 3,
+        aux_layer_period: int = 0,
     ) -> None:
         super(RNN, self).__init__()

@@ -106,62 +107,11 @@ class RNN(EncoderInterface):
         )

     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (N, T, C), where N is the batch size,
-            T is the sequence length, C is the feature dimension.
-          x_lens:
-            A tensor of shape (N,), containing the number of frames in `x`
-            before padding.
-          warmup:
-            A floating point value that gradually increases from 0 throughout
-            training; when it is >= 1.0 we are "fully warmed up". It is used
-            to turn modules on sequentially.
-
-        Returns:
-          A tuple of 2 tensors:
-            - embeddings: its shape is (N, T', d_model), where T' is the output
-              sequence lengths.
-            - lengths: a tensor of shape (batch_size,) containing the number of
-              frames in `embeddings` before padding.
-        """
-        x = self.encoder_embed(x)
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issue an warning
-        #
-        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
-        lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
-
-        x = self.encoder(x, warmup)
-
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        return x, lengths
-
-    @torch.jit.export
-    def get_init_states(
-        self, device: torch.device
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Get model initial states."""
-        # for rnn hidden states
-        hidden_states = torch.zeros(
-            (self.num_encoder_layers, self.d_model), device=device
-        )
-        cell_states = torch.zeros(
-            (self.num_encoder_layers, self.rnn_hidden_size), device=device
-        )
-        return (hidden_states, cell_states)
-
-    @torch.jit.export
-    def infer(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: Tuple[torch.Tensor, torch.Tensor],
+        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
@@ -172,11 +122,15 @@ class RNN(EncoderInterface):
             A tensor of shape (N,), containing the number of frames in `x`
             before padding.
           states:
-            It is a list of 2 tensors.
+            A tuple of 2 tensors (optional). It is for streaming inference.
             states[0] is the hidden states of all layers,
             with shape of (num_layers, N, d_model);
             states[1] is the cell states of all layers,
             with shape of (num_layers, N, rnn_hidden_size).
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.

         Returns:
           A tuple of 3 tensors:
@@ -186,36 +140,57 @@ class RNN(EncoderInterface):
               frames in `embeddings` before padding.
             - updated states, whose shape is same as the input states.
         """
-        assert not self.training
-        assert len(states) == 2
-        # for hidden state
-        assert states[0].shape == (
-            self.num_encoder_layers,
-            x.size(0),
-            self.d_model,
-        )
-        # for cell state
-        assert states[1].shape == (
-            self.num_encoder_layers,
-            x.size(0),
-            self.rnn_hidden_size,
-        )
+        x = self.encoder_embed(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         # lengths = ((x_lens - 1) // 2 - 1) // 2  # issue an warning
         #
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        # we will cut off 1 frame on each side of encoder_embed output
-        lengths -= 2
-
-        embed = self.encoder_embed(x)
-        embed = embed[:, 1:-1, :]
-        embed = embed.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        x, states = self.encoder.infer(embed, states)
+        if not torch.jit.is_tracing():
+            assert x.size(0) == lengths.max().item()
+
+        if states is None:
+            x = self.encoder(x, warmup=warmup)[0]
+            # torch.jit.trace requires returned types be the same as annotated
+            new_states = (torch.empty(0), torch.empty(0))
+        else:
+            # we cut off 1 frame on each side of encoder_embed output
+            lengths -= 2
+            x = x[1:-1, :, :]
+
+            assert not self.training
+            assert len(states) == 2
+            if not torch.jit.is_tracing():
+                # for hidden state
+                assert states[0].shape == (
+                    self.num_encoder_layers,
+                    x.size(1),
+                    self.d_model,
+                )
+                # for cell state
+                assert states[1].shape == (
+                    self.num_encoder_layers,
+                    x.size(1),
+                    self.rnn_hidden_size,
+                )
+            x, new_states = self.encoder(x, states)

         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        return x, lengths, states
+        return x, lengths, new_states
+
+    def get_init_states(
+        self, device: torch.device = torch.device("cpu")
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get model initial states."""
+        # for rnn hidden states
+        hidden_states = torch.zeros(
+            (self.num_encoder_layers, self.d_model), device=device
+        )
+        cell_states = torch.zeros(
+            (self.num_encoder_layers, self.rnn_hidden_size), device=device
+        )
+        return (hidden_states, cell_states)


 class RNNEncoderLayer(nn.Module):
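One detail worth noting in the hunk above: torch.jit.trace() requires every call path to return the same container structure, which is why the non-streaming branch returns (torch.empty(0), torch.empty(0)) instead of None. A toy illustration of the constraint, using a hypothetical module that is not part of this commit:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(
        self, x: torch.Tensor, s: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if s is None:
            # Returning (x, None) would give the two paths different return
            # types; an empty tensor keeps both paths at (Tensor, Tensor).
            return x, torch.empty(0)
        return x, s + 1


# Traces the s-is-None path; the recorded graph still outputs a (Tensor, Tensor) pair.
traced = torch.jit.trace(Toy(), (torch.zeros(2),))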
@@ -271,7 +246,12 @@ class RNNEncoderLayer(nn.Module):
         )
         self.dropout = nn.Dropout(dropout)

-    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
+    def forward(
+        self,
+        src: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer.

@@ -280,6 +260,12 @@ class RNNEncoderLayer(nn.Module):
             The sequence to the encoder layer (required).
             Its shape is (S, N, E), where S is the sequence length,
             N is the batch size, and E is the feature number.
+          states:
+            A tuple of 2 tensors (optional). It is for streaming inference.
+            states[0] is the hidden state of this layer,
+            with shape of (1, N, d_model);
+            states[1] is the cell state of this layer,
+            with shape of (1, N, rnn_hidden_size).
           warmup:
             It controls selective bypass of layers; if < 1.0, we will
             bypass layers more frequently.
@@ -299,7 +285,19 @@ class RNNEncoderLayer(nn.Module):
             alpha = 1.0

         # lstm module
-        src_lstm = self.lstm(src)[0]
+        if states is None:
+            src_lstm = self.lstm(src)[0]
+            # torch.jit.trace requires returned types be the same as annotated
+            new_states = (torch.empty(0), torch.empty(0))
+        else:
+            assert not self.training
+            assert len(states) == 2
+            if not torch.jit.is_tracing():
+                # for hidden state
+                assert states[0].shape == (1, src.size(1), self.d_model)
+                # for cell state
+                assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
+            src_lstm, new_states = self.lstm(src, states)
         src = src + self.dropout(src_lstm)

         # feed forward module
@@ -310,41 +308,6 @@ class RNNEncoderLayer(nn.Module):
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig

-        return src
-
-    @torch.jit.export
-    def infer(
-        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Pass the input through the encoder layer.
-
-        Args:
-          src:
-            The sequence to the encoder layer (required).
-            Its shape is (S, N, d_model), where S is the sequence length,
-            N is the batch size.
-          states:
-            It is a tuple of 2 tensors.
-            states[0] is the hidden state, with shape of (1, N, d_model);
-            states[1] is the cell state, with shape of (1, N, rnn_hidden_size).
-        """
-        assert not self.training
-        assert len(states) == 2
-        # for hidden state
-        assert states[0].shape == (1, src.size(1), self.d_model)
-        # for cell state
-        assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
-
-        # lstm module
-        src_lstm, new_states = self.lstm(src, states)
-        src = src + self.dropout(src_lstm)
-
-        # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
-
-        src = self.norm_final(self.balancer(src))
-
         return src, new_states

@@ -373,11 +336,11 @@ class RNNEncoder(nn.Module):
         self.d_model = encoder_layer.d_model
         self.rnn_hidden_size = encoder_layer.rnn_hidden_size

-        self.use_random_combiner = False
+        self.aux_layers: List[int] = []
+        self.combiner: Optional[nn.Module] = None
         if aux_layers is not None:
             assert len(set(aux_layers)) == len(aux_layers)
             assert num_layers - 1 not in aux_layers
-            self.use_random_combiner = True
             self.aux_layers = aux_layers + [num_layers - 1]
             self.combiner = RandomCombine(
                 num_inputs=len(self.aux_layers),
@@ -386,7 +349,12 @@ class RNNEncoder(nn.Module):
                 stddev=2.0,
             )

-    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
+    def forward(
+        self,
+        src: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer in turn.

@@ -395,75 +363,66 @@ class RNNEncoder(nn.Module):
             The sequence to the encoder layer (required).
             Its shape is (S, N, E), where S is the sequence length,
             N is the batch size, and E is the feature number.
-          warmup:
-            It controls selective bypass of of layers; if < 1.0, we will
-            bypass layers more frequently.
-        """
-        output = src
-
-        outputs = []
-
-        for i, mod in enumerate(self.layers):
-            output = mod(output, warmup=warmup)
-            if self.use_random_combiner:
-                if i in self.aux_layers:
-                    outputs.append(output)
-
-        if self.use_random_combiner:
-            output = self.combiner(outputs)
-
-        return output
-
-    @torch.jit.export
-    def infer(
-        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Pass the input through the encoder layer.
-
-        Args:
-          src:
-            The sequence to the encoder layer (required).
-            Its shape is (S, N, d_model), where S is the sequence length,
-            N is the batch size.
           states:
-            It is a list of 2 tensors.
+            A tuple of 2 tensors (optional). It is for streaming inference.
             states[0] is the hidden states of all layers,
             with shape of (num_layers, N, d_model);
             states[1] is the cell states of all layers,
             with shape of (num_layers, N, rnn_hidden_size).
+          warmup:
+            It controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently.
         """
-        assert not self.training
-        assert len(states) == 2
-        # for hidden state
-        assert states[0].shape == (self.num_layers, src.size(1), self.d_model)
-        # for cell state
-        assert states[1].shape == (
-            self.num_layers,
-            src.size(1),
-            self.rnn_hidden_size,
-        )
+        if states is not None:
+            assert not self.training
+            assert len(states) == 2
+            if not torch.jit.is_tracing():
+                # for hidden state
+                assert states[0].shape == (
+                    self.num_layers,
+                    src.size(1),
+                    self.d_model,
+                )
+                # for cell state
+                assert states[1].shape == (
+                    self.num_layers,
+                    src.size(1),
+                    self.rnn_hidden_size,
+                )

         output = src

+        outputs = []
+
         new_hidden_states = []
         new_cell_states = []
-        for layer_index, mod in enumerate(self.layers):
-            layer_states = (
-                states[0][
-                    layer_index : layer_index + 1, :, :
-                ],  # h: (1, N, d_model)
-                states[1][
-                    layer_index : layer_index + 1, :, :
-                ],  # c: (1, N, rnn_hidden_size)
-            )
-            output, (h, c) = mod.infer(output, layer_states)
-            new_hidden_states.append(h)
-            new_cell_states.append(c)

-        new_states = (
-            torch.cat(new_hidden_states, dim=0),
-            torch.cat(new_cell_states, dim=0),
-        )
+        for i, mod in enumerate(self.layers):
+            if states is None:
+                output = mod(output, warmup=warmup)[0]
+            else:
+                layer_state = (
+                    states[0][i : i + 1, :, :],  # h: (1, N, d_model)
+                    states[1][i : i + 1, :, :],  # c: (1, N, rnn_hidden_size)
+                )
+                output, (h, c) = mod(output, layer_state)
+                new_hidden_states.append(h)
+                new_cell_states.append(c)
+
+            if self.combiner is not None and i in self.aux_layers:
+                outputs.append(output)
+
+        if self.combiner is not None:
+            output = self.combiner(outputs)
+
+        if states is None:
+            new_states = (torch.empty(0), torch.empty(0))
+        else:
+            new_states = (
+                torch.cat(new_hidden_states, dim=0),
+                torch.cat(new_cell_states, dim=0),
+            )

         return output, new_states

@@ -804,9 +763,9 @@ if __name__ == "__main__":
     m = RNN(
         num_features=feature_dim,
         d_model=512,
-        rnn_hidden_size=1024,
+        rnn_hidden_size=1536,
         dim_feedforward=2048,
-        num_encoder_layers=12,
+        num_encoder_layers=10,
     )
     batch_size = 5
     seq_len = 20
@@ -19,7 +19,7 @@
 To run this file, do:

     cd icefall/egs/librispeech/ASR
-    python ./pruned_transducer_stateless3/test_scaling_converter.py
+    python ./lstm_transducer_stateless/test_scaling_converter.py
 """

 import copy
@@ -389,8 +389,9 @@ class ScaledLSTM(nn.LSTM):
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode bidirectional=False
-        super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs)
+        if "bidirectional" in kwargs:
+            assert kwargs["bidirectional"] is False
+        super(ScaledLSTM, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []
         self._scales = []
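The effect of the ScaledLSTM hunk above: bidirectional=False is no longer silently forced; an explicit `bidirectional` keyword is now accepted only if it is False. A small sketch of the resulting contract (sizes are illustrative, and the import assumes the recipe directory is on sys.path):

from scaling import ScaledLSTM

ScaledLSTM(input_size=512, hidden_size=1024)                       # OK: nn.LSTM defaults to unidirectional
ScaledLSTM(input_size=512, hidden_size=1024, bidirectional=False)  # OK: passes the new assert
# ScaledLSTM(input_size=512, hidden_size=1024, bidirectional=True) # would raise AssertionError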
@@ -170,7 +170,7 @@
     return embedding


-def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM):
+def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
     """Convert an instance of ScaledLSTM to nn.LSTM.

     Args:
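These converter pieces back the convert_scaled_to_non_scaled() call used in export.py: before tracing or ONNX export, each ScaledLSTM is rewritten as a plain nn.LSTM so that only stock modules remain in the graph. A minimal usage sketch, assuming a transducer model built as in export.py's main():

import torch.nn as nn

from scaling_converter import convert_scaled_to_non_scaled


def prepare_for_export(model: nn.Module) -> nn.Module:
    # Mirrors what export.py does right before exporting.
    model.eval()
    return convert_scaled_to_non_scaled(model, inplace=True)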