support jit trace

This commit is contained in:
yaozengwei 2022-08-09 19:29:21 +08:00
parent 03b056ca37
commit 45c7894111
6 changed files with 1041 additions and 181 deletions

View File

@ -0,0 +1,581 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script exports a trained transducer model to TorchScript
# (via torch.jit.trace), ONNX, or a plain state_dict, optionally
# averaging several saved checkpoints first.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export to ONNX format
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
It will generate the following three files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
(3) Export `model.state_dict()`
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless/decode.py \
--exp-dir ./lstm_transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""
import argparse
import logging
from pathlib import Path
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. Three files will be generated:
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
states = (states[0].unsqueeze(1), states[1].unsqueeze(1))
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. Users should not
project encoder_out/decoder_out themselves; the exported joiner
does that for them.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
hidden_states = states[0].unsqueeze(1)
cell_states = states[1].unsqueeze(1)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
warmup = 1.0
torch.onnx.export(
encoder_model,
(x, x_lens, (hidden_states, cell_states), warmup),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens", "hidden_states", "cell_states", "warmup"],
output_names=[
"encoder_out",
"encoder_out_lens",
"new_hidden_states",
"new_cell_states",
],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"hidden_states": {1: "N"},
"cell_states": {1: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
"new_hidden_states": {1: "N"},
"new_cell_states": {1: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
- encoder_out: a tensor of shape (N, encoder_out_dim)
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
- joiner_out: a tensor of shape (N, vocab_size)
Note: The argument project_input is fixed to True. Users should not
project encoder_out/decoder_out themselves; the exported joiner
does that for them.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
project_input = True
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out", "project_input"],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
def export_all_in_one_onnx(
encoder_filename: str,
decoder_filename: str,
joiner_filename: str,
all_in_one_filename: str,
):
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
combined_model = onnx.compose.merge_models(
encoder_onnx, decoder_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
if params.onnx is True:
opset_version = 11
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename,
)
elif params.jit_trace is True:
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
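The three files written by `--jit-trace 1` can also be exercised by hand, along the lines of the sketch below (not part of this commit; the layer count, dimensions, context size, and blank id are assumptions that must match the trained model -- `jit_pretrained.py`, added in this commit, is the complete example):

# Minimal sketch of driving the traced encoder/decoder/joiner directly.
# Assumed sizes: 12 encoder layers, d_model=512, rnn_hidden_size=1024,
# context_size=2, blank_id=0 -- adjust to the actual training config.
import torch

encoder = torch.jit.load("encoder_jit_trace.pt")
decoder = torch.jit.load("decoder_jit_trace.pt")
joiner = torch.jit.load("joiner_jit_trace.pt")

x = torch.randn(1, 100, 80)                      # (N, T, C) fbank features
x_lens = torch.tensor([100], dtype=torch.int64)  # valid frames per utterance
h0 = torch.zeros(12, 1, 512)                     # (num_layers, N, d_model)
c0 = torch.zeros(12, 1, 1024)                    # (num_layers, N, rnn_hidden_size)

encoder_out, encoder_out_lens, _ = encoder(x, x_lens, (h0, c0))

# One greedy step on the first output frame, starting from a blank context.
y = torch.tensor([[0, 0]], dtype=torch.int64)    # (N, context_size)
decoder_out = decoder(y, torch.tensor([False])).squeeze(1)
logits = joiner(encoder_out[:, 0], decoder_out)  # (N, vocab_size)
print(logits.argmax(dim=-1))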

View File

@ -0,0 +1,319 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = joiner(
current_encoder_out,
decoder_out,
)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens, _ = encoder(
x=features,
x_lens=feature_lengths,
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
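The packing trick in `greedy_search()` above is easiest to see on a toy batch; this standalone sketch (not part of the commit) shows that `batch_sizes[t]` counts how many length-sorted utterances are still active at frame `t`, which is why `decoder_out` is trimmed to `batch_size` at every step:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

encoder_out = torch.zeros(3, 5, 4)           # (N, T, C), zero-padded
encoder_out_lens = torch.tensor([5, 3, 2])   # valid frames per utterance

packed = pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)
print(packed.batch_sizes.tolist())           # [3, 3, 2, 1, 1]
print(packed.data.shape)                     # torch.Size([10, 4]); 10 == 5 + 3 + 2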

View File

@ -51,7 +51,8 @@ class RNN(EncoderInterface):
Dropout value for model-level warmup (default=0.075).
aux_layer_period (int):
Peroid of auxiliary layers used for randomly combined during training.
- If not larger than 0, will not use the random combiner.
+ If set to 0, will not use the random combiner (Default).
+ You can set a positive integer to use the random combiner, e.g., 3.
"""
def __init__(
@ -64,7 +65,7 @@ class RNN(EncoderInterface):
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
- aux_layer_period: int = 3,
+ aux_layer_period: int = 0,
) -> None:
super(RNN, self).__init__()
@ -106,62 +107,11 @@ class RNN(EncoderInterface):
)
def forward(
- self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Args:
- x:
- The input tensor. Its shape is (N, T, C), where N is the batch size,
- T is the sequence length, C is the feature dimension.
- x_lens:
- A tensor of shape (N,), containing the number of frames in `x`
- before padding.
- warmup:
- A floating point value that gradually increases from 0 throughout
- training; when it is >= 1.0 we are "fully warmed up". It is used
- to turn modules on sequentially.
- Returns:
- A tuple of 2 tensors:
- - embeddings: its shape is (N, T', d_model), where T' is the output
- sequence lengths.
- - lengths: a tensor of shape (batch_size,) containing the number of
- frames in `embeddings` before padding.
- """
- x = self.encoder_embed(x)
- x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
- #
- # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
- lengths = (((x_lens - 1) >> 1) - 1) >> 1
- assert x.size(0) == lengths.max().item()
- x = self.encoder(x, warmup)
- x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
- return x, lengths
- @torch.jit.export
- def get_init_states(
- self, device: torch.device
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Get model initial states."""
- # for rnn hidden states
- hidden_states = torch.zeros(
- (self.num_encoder_layers, self.d_model), device=device
- )
- cell_states = torch.zeros(
- (self.num_encoder_layers, self.rnn_hidden_size), device=device
- )
- return (hidden_states, cell_states)
- @torch.jit.export
- def infer(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
- states: Tuple[torch.Tensor, torch.Tensor],
+ states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
@ -172,11 +122,15 @@ class RNN(EncoderInterface):
A tensor of shape (N,), containing the number of frames in `x`
before padding.
states:
- It is a list of 2 tensors.
+ A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
+ warmup:
+ A floating point value that gradually increases from 0 throughout
+ training; when it is >= 1.0 we are "fully warmed up". It is used
+ to turn modules on sequentially.
Returns:
A tuple of 3 tensors:
@ -186,36 +140,57 @@ class RNN(EncoderInterface):
frames in `embeddings` before padding.
- updated states, whose shape is same as the input states.
"""
- assert not self.training
- assert len(states) == 2
- # for hidden state
- assert states[0].shape == (
- self.num_encoder_layers,
- x.size(0),
- self.d_model,
- )
- # for cell state
- assert states[1].shape == (
- self.num_encoder_layers,
- x.size(0),
- self.rnn_hidden_size,
- )
+ x = self.encoder_embed(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
- # we will cut off 1 frame on each side of encoder_embed output
- lengths -= 2
- embed = self.encoder_embed(x)
- embed = embed[:, 1:-1, :]
- embed = embed.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- x, states = self.encoder.infer(embed, states)
+ if not torch.jit.is_tracing():
+ assert x.size(0) == lengths.max().item()
+ if states is None:
+ x = self.encoder(x, warmup=warmup)[0]
+ # torch.jit.trace requires returned types be the same as annotated
+ new_states = (torch.empty(0), torch.empty(0))
+ else:
+ # we cut off 1 frame on each side of encoder_embed output
+ lengths -= 2
+ x = x[1:-1, :, :]
+ assert not self.training
+ assert len(states) == 2
+ if not torch.jit.is_tracing():
+ # for hidden state
+ assert states[0].shape == (
+ self.num_encoder_layers,
+ x.size(1),
+ self.d_model,
+ )
+ # for cell state
+ assert states[1].shape == (
+ self.num_encoder_layers,
+ x.size(1),
+ self.rnn_hidden_size,
+ )
+ x, new_states = self.encoder(x, states)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
- return x, lengths, states
+ return x, lengths, new_states
+ def get_init_states(
+ self, device: torch.device = torch.device("cpu")
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get model initial states."""
+ # for rnn hidden states
+ hidden_states = torch.zeros(
+ (self.num_encoder_layers, self.d_model), device=device
+ )
+ cell_states = torch.zeros(
+ (self.num_encoder_layers, self.rnn_hidden_size), device=device
+ )
+ return (hidden_states, cell_states)
class RNNEncoderLayer(nn.Module):
@ -271,7 +246,12 @@ class RNNEncoderLayer(nn.Module):
)
self.dropout = nn.Dropout(dropout)
- def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
+ def forward(
+ self,
+ src: torch.Tensor,
+ states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer.
@ -280,6 +260,12 @@ class RNNEncoderLayer(nn.Module):
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
+ states:
+ A tuple of 2 tensors (optional). It is for streaming inference.
+ states[0] is the hidden states of all layers,
+ with shape of (1, N, d_model);
+ states[1] is the cell states of all layers,
+ with shape of (1, N, rnn_hidden_size).
warmup:
It controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
@ -299,7 +285,19 @@ class RNNEncoderLayer(nn.Module):
alpha = 1.0
# lstm module
- src_lstm = self.lstm(src)[0]
+ if states is None:
+ src_lstm = self.lstm(src)[0]
+ # torch.jit.trace requires returned types be the same as annotated
+ new_states = (torch.empty(0), torch.empty(0))
+ else:
+ assert not self.training
+ assert len(states) == 2
+ if not torch.jit.is_tracing():
+ # for hidden state
+ assert states[0].shape == (1, src.size(1), self.d_model)
+ # for cell state
+ assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
+ src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
# feed forward module
@ -310,41 +308,6 @@ class RNNEncoderLayer(nn.Module):
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
- return src
- @torch.jit.export
- def infer(
- self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Pass the input through the encoder layer.
- Args:
- src:
- The sequence to the encoder layer (required).
- Its shape is (S, N, d_model), where S is the sequence length,
- N is the batch size.
- states:
- It is a tuple of 2 tensors.
- states[0] is the hidden state, with shape of (1, N, d_model);
- states[1] is the cell state, with shape of (1, N, rnn_hidden_size).
- """
- assert not self.training
- assert len(states) == 2
- # for hidden state
- assert states[0].shape == (1, src.size(1), self.d_model)
- # for cell state
- assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
- # lstm module
- src_lstm, new_states = self.lstm(src, states)
- src = src + self.dropout(src_lstm)
- # feed forward module
- src = src + self.dropout(self.feed_forward(src))
- src = self.norm_final(self.balancer(src))
return src, new_states
@ -373,11 +336,11 @@ class RNNEncoder(nn.Module):
self.d_model = encoder_layer.d_model
self.rnn_hidden_size = encoder_layer.rnn_hidden_size
- self.use_random_combiner = False
+ self.aux_layers: List[int] = []
+ self.combiner: Optional[nn.Module] = None
if aux_layers is not None:
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
- self.use_random_combiner = True
self.aux_layers = aux_layers + [num_layers - 1]
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
@ -386,7 +349,12 @@ class RNNEncoder(nn.Module):
stddev=2.0,
)
- def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
+ def forward(
+ self,
+ src: torch.Tensor,
+ states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ warmup: float = 1.0,
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer in turn.
@ -395,75 +363,66 @@ class RNNEncoder(nn.Module):
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
- warmup:
- It controls selective bypass of of layers; if < 1.0, we will
- bypass layers more frequently.
- """
- output = src
- outputs = []
- for i, mod in enumerate(self.layers):
- output = mod(output, warmup=warmup)
- if self.use_random_combiner:
- if i in self.aux_layers:
- outputs.append(output)
- if self.use_random_combiner:
- output = self.combiner(outputs)
- return output
- @torch.jit.export
- def infer(
- self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Pass the input through the encoder layer.
- Args:
- src:
- The sequence to the encoder layer (required).
- Its shape is (S, N, d_model), where S is the sequence length,
- N is the batch size.
states:
- It is a list of 2 tensors.
+ A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
+ warmup:
+ It controls selective bypass of of layers; if < 1.0, we will
+ bypass layers more frequently.
"""
- assert not self.training
- assert len(states) == 2
- # for hidden state
- assert states[0].shape == (self.num_layers, src.size(1), self.d_model)
- # for cell state
- assert states[1].shape == (
- self.num_layers,
- src.size(1),
- self.rnn_hidden_size,
- )
+ if states is not None:
+ assert not self.training
+ assert len(states) == 2
+ if not torch.jit.is_tracing():
+ # for hidden state
+ assert states[0].shape == (
+ self.num_layers,
+ src.size(1),
+ self.d_model,
+ )
+ # for cell state
+ assert states[1].shape == (
+ self.num_layers,
+ src.size(1),
+ self.rnn_hidden_size,
+ )
output = src
+ outputs = []
new_hidden_states = []
new_cell_states = []
- for layer_index, mod in enumerate(self.layers):
- layer_states = (
- states[0][
- layer_index : layer_index + 1, :, :
- ], # h: (1, N, d_model)
- states[1][
- layer_index : layer_index + 1, :, :
- ], # c: (1, N, rnn_hidden_size)
- )
- output, (h, c) = mod.infer(output, layer_states)
- new_hidden_states.append(h)
- new_cell_states.append(c)
- new_states = (
- torch.cat(new_hidden_states, dim=0),
- torch.cat(new_cell_states, dim=0),
- )
+ for i, mod in enumerate(self.layers):
+ if states is None:
+ output = mod(output, warmup=warmup)[0]
+ else:
+ layer_state = (
+ states[0][i : i + 1, :, :], # h: (1, N, d_model)
+ states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size)
+ )
+ output, (h, c) = mod(output, layer_state)
+ new_hidden_states.append(h)
+ new_cell_states.append(c)
+ if self.combiner is not None and i in self.aux_layers:
+ outputs.append(output)
+ if self.combiner is not None:
+ output = self.combiner(outputs)
+ if states is None:
+ new_states = (torch.empty(0), torch.empty(0))
+ else:
+ new_states = (
+ torch.cat(new_hidden_states, dim=0),
+ torch.cat(new_cell_states, dim=0),
+ )
return output, new_states
@ -804,9 +763,9 @@ if __name__ == "__main__":
m = RNN(
num_features=feature_dim,
d_model=512,
- rnn_hidden_size=1024,
+ rnn_hidden_size=1536,
dim_feedforward=2048,
- num_encoder_layers=12,
+ num_encoder_layers=10,
)
batch_size = 5
seq_len = 20
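The net effect of the lstm.py changes is that the old `forward()`/`infer()` pair collapses into one trace-friendly `forward()`. A hedged sketch of the two call patterns, assuming an `RNN` instance `m` built exactly as in the `__main__` test above (10 layers, d_model=512, rnn_hidden_size=1536, 80-dim features):

import torch

N, T, C = 5, 20, 80
x = torch.randn(N, T, C)
x_lens = torch.full((N,), T, dtype=torch.int64)

# Offline/training path: no states are passed; warmup gates the layers and
# the returned states are empty placeholders.
y, y_lens, _ = m(x, x_lens, warmup=1.0)

# Streaming path: pass LSTM states and receive updated ones. This branch
# asserts eval mode and trims one frame on each side of the embed output.
m.eval()
states = (
    torch.zeros(10, N, 512),   # hidden states: (num_layers, N, d_model)
    torch.zeros(10, N, 1536),  # cell states: (num_layers, N, rnn_hidden_size)
)
y, y_lens, new_states = m(x, x_lens, states)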

View File

@ -19,7 +19,7 @@
To run this file, do:
cd icefall/egs/librispeech/ASR
- python ./pruned_transducer_stateless3/test_scaling_converter.py
+ python ./lstm_transducer_stateless/test_scaling_converter.py
"""
import copy

View File

@ -389,8 +389,9 @@ class ScaledLSTM(nn.LSTM):
initial_speed: float = 1.0,
**kwargs
):
- # Hardcode bidirectional=False
- super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs)
+ if "bidirectional" in kwargs:
+ assert kwargs["bidirectional"] is False
+ super(ScaledLSTM, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self._scales_names = []
self._scales = []
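A small sketch of the relaxed ScaledLSTM contract above (hypothetical sizes; `ScaledLSTM` lives in the recipe's scaling.py):

from scaling import ScaledLSTM

# bidirectional may now be passed explicitly, but only False is accepted.
lstm_a = ScaledLSTM(input_size=512, hidden_size=1024)                        # fine, as before
lstm_b = ScaledLSTM(input_size=512, hidden_size=1024, bidirectional=False)   # also fine
# ScaledLSTM(input_size=512, hidden_size=1024, bidirectional=True)  # would trip the assert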

View File

@ -170,7 +170,7 @@ def scaled_embedding_to_embedding(
return embedding
- def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM):
+ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
"""Convert an instance of ScaledLSTM to nn.LSTM.
Args: