Replace codes copied from librispeech recipe with symlink

This commit is contained in:
whsqkaak 2024-06-13 17:13:17 +09:00
parent db5c61d371
commit ed9cc836ce
17 changed files with 17 additions and 12268 deletions

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py

View File

@ -1,151 +0,0 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
cut_id: str,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
if params.decoding_method == "fast_beam_search":
assert decoding_graph is not None
assert device == decoding_graph.device
self.params = params
self.cut_id = cut_id
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# how many frames have been processed. (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, after subsampling (i.e. a
# cumulative sum of the second return value of
# encoder.streaming_forward
self.done_frames: int = 0
# It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2
# 1) feature embedding: out_lens=(x_lens-7)//2
# 2) output subsampling: out_lens=(out_lens+1)//2
self.pad_length = 7
if params.decoding_method == "greedy_search":
self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[-1] * (params.context_size - 1) + [params.blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
decoding_graph
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
@property
def id(self) -> str:
return self.cut_id
def set_features(
self,
features: torch.Tensor,
tail_pad_len: int = 0,
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length + tail_pad_len),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_features, ret_length
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.params.decoding_method == "greedy_search":
return self.hyp[self.params.context_size :] # noqa
elif self.params.decoding_method == "modified_beam_search":
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.params.context_size :] # noqa
else:
assert self.params.decoding_method == "fast_beam_search"
return self.hyp

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py

View File

@ -1,109 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
if torch.jit.is_tracing():
# This is for exporting to PNNX via ONNX
embedding_out = self.embedding(y)
else:
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
return embedding_out

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py

View File

@ -1,43 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py

View File

@ -1,653 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
"""
This script exports a transducer model from PyTorch to ONNX.
- Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--exp-dir $repo/exp/
It will generate the following 3 files in exp
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py for how to use the exported models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
"""Please see the help information of Zipformer.streaming_forward"""
N = x.size(0)
T = x.size(1)
x_lens = torch.tensor([T] * N, device=x.device)
output, _, new_states = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
output = self.encoder_proj(output)
# Now output is of shape (N, T, joiner_dim)
return output, new_states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""
Onnx model inputs:
- 0: src
- many state tensors (the exact number depending on the actual model)
Onnx model outputs:
- 0: output, its shape is (N, T, joiner_dim)
- many state tensors (the exact number depending on the actual model)
Args:
encoder_model:
The model to be exported
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
)
decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2
pad_length = 7
T = decode_chunk_len + pad_length
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"pad_length: {pad_length}")
logging.info(f"T: {T}")
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.encoder.get_init_state()
num_encoders = encoder_model.encoder.num_encoders
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, name, N):
for i, s in enumerate(tensors):
logging.info(f"{name}_{i}.shape: {s.shape}")
inputs[f"{name}_{i}"] = {N: "N"}
outputs[f"new_{name}_{i}"] = {N: "N"}
input_names.append(f"{name}_{i}")
output_names.append(f"new_{name}_{i}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims))
attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels))
ds = encoder_model.encoder.zipformer_downsampling_factors
left_context_len = encoder_model.encoder.left_context_len
left_context_len = [left_context_len // k for k in ds]
left_context_len = ",".join(map(str, left_context_len))
meta_data = {
"model_type": "zipformer",
"version": "1",
"model_author": "k2-fsa",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 39
"num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims,
"attention_dims": attention_dims,
"cnn_module_kernels": cnn_module_kernels,
"left_context_len": left_context_len,
}
logging.info(f"meta_data: {meta_data}")
# (num_encoder_layers, 1)
cached_len = init_state[num_encoders * 0 : num_encoders * 1]
# (num_encoder_layers, 1, encoder_dim)
cached_avg = init_state[num_encoders * 1 : num_encoders * 2]
# (num_encoder_layers, left_context_len, 1, attention_dim)
cached_key = init_state[num_encoders * 2 : num_encoders * 3]
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
cached_val = init_state[num_encoders * 3 : num_encoders * 4]
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
cached_val2 = init_state[num_encoders * 4 : num_encoders * 5]
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6]
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7]
build_inputs_outputs(cached_len, "cached_len", 1)
build_inputs_outputs(cached_avg, "cached_avg", 1)
build_inputs_outputs(cached_key, "cached_key", 2)
build_inputs_outputs(cached_val, "cached_val", 2)
build_inputs_outputs(cached_val2, "cached_val2", 2)
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
logging.info(inputs)
logging.info(outputs)
logging.info(input_names)
logging.info(output_names)
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
if params.use_averaged_model:
suffix += "-with-averaged-model"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py

View File

@ -1,872 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/ksponspeech/ASR
./pruned_transducer_stateless7_streaming/decode.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
(3) Export to ONNX format with pretrained.pt
Assume we will export to ONNX format with `epoch-999.pt`.
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(4) Export to ONNX format for triton server
Assume we will export to ONNX format with `epoch-999.pt`.
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx-triton 1 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check
https://github.com/k2-fsa/sherpa/tree/master/triton
for how to use the exported models outside of icefall.
"""
import argparse
import logging
from pathlib import Path
import k2
import onnxruntime
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx-triton",
type=str2bool,
default=False,
help="""If True, --onnx would export model into the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton.
""",
)
parser.add_argument(
"--fp16",
action="store_true",
help="whether to export fp16 onnx model, default false",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
for a, b in zip(xlist, blist):
try:
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
except AssertionError as error:
if tolerate_small_mismatch:
print("small mismatch detected", error)
else:
return False
return True
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
batch_size = 17
seq_len = 101
torch.manual_seed(0)
x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
states = stack_states(initial_states)
left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
encoder_attention_dim = encoder_model.encoders[0].attention_dim
len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15
avg_cache = torch.cat(
states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
).transpose(
0, 1
) # [B,15,384]
cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
0, 1
) # [B,2*15,384,cnn_kernel-1]
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dim - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in states[
2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
]
]
attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)
torch.onnx.export(
encoder_model_wrapper,
(x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"x",
"x_lens",
"len_cache",
"avg_cache",
"attn_cache",
"cnn_cache",
],
output_names=[
"encoder_out",
"encoder_out_lens",
"new_len_cache",
"new_avg_cache",
"new_attn_cache",
"new_cnn_cache",
],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
"len_cache": {0: "N"},
"avg_cache": {0: "N"},
"attn_cache": {0: "N"},
"cnn_cache": {0: "N"},
"new_len_cache": {0: "N"},
"new_avg_cache": {0: "N"},
"new_attn_cache": {0: "N"},
"new_cnn_cache": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
# Test onnx encoder with torch native encoder
encoder_model.eval()
(
encoder_out_torch,
encoder_out_lens_torch,
new_states_torch,
) = encoder_model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
ort_session = onnxruntime.InferenceSession(
str(encoder_filename), providers=["CPUExecutionProvider"]
)
ort_inputs = {
"x": x.numpy(),
"x_lens": x_lens.numpy(),
"len_cache": len_cache.numpy(),
"avg_cache": avg_cache.numpy(),
"attn_cache": attn_cache.numpy(),
"cnn_cache": cnn_cache.numpy(),
}
ort_outs = ort_session.run(None, ort_inputs)
assert test_acc(
[encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
)
logging.info(f"{encoder_filename} acc test succeeded.")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_decoder_model_onnx_triton(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
decoder_model = TritonOnnxDecoder(decoder_model)
torch.onnx.export(
decoder_model,
(y,),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_joiner_model_onnx_triton(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
- encoder_out: a tensor of shape (N, encoder_out_dim)
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
- joiner_out: a tensor of shape (N, vocab_size)
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
joiner_model = TritonOnnxJoiner(joiner_model)
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out"],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.onnx:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
if not params.onnx_triton:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
else:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx_triton(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx_triton(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
if params.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
import sys
sys.exit(1)
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
export_onnx_fp16(encoder_filename, encoder_fp16_filename)
decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
export_onnx_fp16(decoder_filename, decoder_fp16_filename)
joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
export_onnx_fp16(joiner_filename, joiner_fp16_filename)
if not params.onnx_triton:
encoder_proj_filename = str(joiner_filename).replace(
".onnx", "_encoder_proj.onnx"
)
encoder_proj_fp16_filename = (
params.exp_dir / "joiner_encoder_proj_fp16.onnx"
)
export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)
decoder_proj_filename = str(joiner_filename).replace(
".onnx", "_decoder_proj.onnx"
)
decoder_proj_fp16_filename = (
params.exp_dir / "joiner_decoder_proj_fp16.onnx"
)
export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
elif params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward
model.encoder.__class__.non_streaming_forward = torch.jit.export(
model.encoder.__class__.non_streaming_forward
)
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py

View File

@ -1,64 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py

View File

@ -1,198 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = nn.Linear(
encoder_dim,
vocab_size,
)
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
# x.T_dim == max(x_len)
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py

View File

@ -1,241 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
"""
This script checks that exported ONNX models produce the same output
with the given torchscript model for the same input.
1. Export the model via torch.jit.trace()
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file
./pruned_transducer_stateless7_streaming/onnx_check.py \
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
"""
import argparse
import logging
import torch
from onnx_pretrained import OnnxModel
from zipformer import stack_states
from icefall import is_module_available
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-encoder-filename",
required=True,
type=str,
help="Path to the torchscript encoder model",
)
parser.add_argument(
"--jit-decoder-filename",
required=True,
type=str,
help="Path to the torchscript decoder model",
)
parser.add_argument(
"--jit-joiner-filename",
required=True,
type=str,
help="Path to the torchscript joiner model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the ONNX encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the ONNX decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the ONNX joiner model",
)
return parser
def test_encoder(
torch_encoder_model: torch.jit.ScriptModule,
torch_encoder_proj_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
N = torch.randint(1, 100, size=(1,)).item()
T = onnx_model.segment
C = 80
x_lens = torch.tensor([T] * N)
torch_states = [torch_encoder_model.get_init_state() for _ in range(N)]
torch_states = stack_states(torch_states)
onnx_model.init_encoder_states(N)
for i in range(5):
logging.info(f"test_encoder: iter {i}")
x = torch.rand(N, T, C)
torch_encoder_out, _, torch_states = torch_encoder_model(
x, x_lens, torch_states
)
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
onnx_encoder_out = onnx_model.run_encoder(x)
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
(torch_encoder_out - onnx_encoder_out).abs().max()
)
def test_decoder(
torch_decoder_model: torch.jit.ScriptModule,
torch_decoder_proj_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
context_size = onnx_model.context_size
vocab_size = onnx_model.vocab_size
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_decoder: iter {i}, N={N}")
x = torch.randint(
low=1,
high=vocab_size,
size=(N, context_size),
dtype=torch.int64,
)
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
torch_decoder_out = torch_decoder_out.squeeze(1)
onnx_decoder_out = onnx_model.run_decoder(x)
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
(torch_decoder_out - onnx_decoder_out).abs().max()
)
def test_joiner(
torch_joiner_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_joiner: iter {i}, N={N}")
encoder_out = torch.rand(N, encoder_dim)
decoder_out = torch.rand(N, decoder_dim)
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
onnx_joiner_out = onnx_model.run_joiner(
projected_encoder_out, projected_decoder_out
)
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
(torch_joiner_out - onnx_joiner_out).abs().max()
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
onnx_model = OnnxModel(
encoder_model_filename=args.onnx_encoder_filename,
decoder_model_filename=args.onnx_decoder_filename,
joiner_model_filename=args.onnx_joiner_filename,
)
logging.info("Test encoder")
# When exporting the model to onnx, we have already put the encoder_proj
# inside the encoder.
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
logging.info("Test decoder")
# When exporting the model to onnx, we have already put the decoder_proj
# inside the decoder.
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
logging.info("Test joiner")
test_joiner(torch_joiner_model, onnx_model)
logging.info("Finished checking ONNX models")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
torch.manual_seed(20230207)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py

View File

@ -1,231 +0,0 @@
from typing import Optional, Tuple
import torch
class OnnxStreamingEncoder(torch.nn.Module):
"""This class warps the streaming Zipformer to reduce the number of
state tensors for onnx.
https://github.com/k2-fsa/icefall/pull/831
"""
def __init__(self, encoder):
"""
Args:
encoder: An instance of Zipformer Class
"""
super().__init__()
self.model = encoder
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
len_cache: torch.tensor,
avg_cache: torch.tensor,
attn_cache: torch.tensor,
cnn_cache: torch.tensor,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
len_cache:
The cached numbers of past frames.
avg_cache:
The cached average tensors.
attn_cache:
The cached key tensors of the first attention modules.
The cached value tensors of the first attention modules.
The cached value tensors of the second attention modules.
cnn_cache:
The cached left contexts of the first convolution modules.
The cached left contexts of the second convolution modules.
Returns:
Return a tuple containing 2 tensors:
"""
num_encoder_layers = []
encoder_attention_dims = []
states = []
for i, encoder in enumerate(self.model.encoders):
num_encoder_layers.append(encoder.num_layers)
encoder_attention_dims.append(encoder.attention_dim)
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
offset = 0
for num_layer in num_encoder_layers:
states.append(len_cache[offset : offset + num_layer])
offset += num_layer
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
offset = 0
for num_layer in num_encoder_layers:
states.append(avg_cache[offset : offset + num_layer])
offset += num_layer
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
left_context_len = attn_cache.shape[1]
offset = 0
for i, num_layer in enumerate(num_encoder_layers):
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[offset : offset + num_layer, : left_context_len // ds]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
encoder_attention_dim = encoder_attention_dims[i]
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
offset = 0
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
0, 1
) # [B,15]
new_avg_cache = torch.cat(
states[self.model.num_encoders : 2 * self.model.num_encoders]
).transpose(
0, 1
) # [B,15,384]
new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
0, 1
) # [B,2*15,384,cnn_kernel-1]
assert len(set(encoder_attention_dims)) == 1
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dims[0] - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in states[
2 * self.model.num_encoders : 5 * self.model.num_encoders
]
]
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
return (
encoder_out,
encoder_out_lens,
new_len_cache,
new_avg_cache,
new_attn_cache,
new_cnn_cache,
)
class TritonOnnxDecoder(torch.nn.Module):
"""This class warps the Decoder in decoder.py
to remove the scalar input "need_pad".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"""
def __init__(
self,
decoder: torch.nn.Module,
):
"""
Args:
decoder: A instance of Decoder
"""
super().__init__()
self.model = decoder
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
# False to not pad the input. Should be False during inference.
need_pad = False
return self.model(y, need_pad)
class TritonOnnxJoiner(torch.nn.Module):
"""This class warps the Joiner in joiner.py
to remove the scalar input "project_input".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"project_input" is set to True.
Triton solutions only need export joiner to a single joiner.onnx.
"""
def __init__(
self,
joiner: torch.nn.Module,
):
super().__init__()
self.model = joiner
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
# Apply input projections encoder_proj and decoder_proj.
project_input = True
return self.model(encoder_out, decoder_out, project_input)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py

View File

@ -1,497 +0,0 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
"""
This script loads ONNX models exported by ./export-onnx.py
and uses them to decode waves.
1. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--decode-chunk-len 32 \
--exp-dir $repo/exp/
It will generate the following 3 files in $repo/exp
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
2. Run this file with the exported ONNX models
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
model_type = encoder_meta["model_type"]
assert model_type == "zipformer", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"]
attention_dims = encoder_meta["attention_dims"]
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
left_context_len = encoder_meta["left_context_len"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
attention_dims = to_int_list(attention_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"attention_dims: {attention_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
num_encoders = len(num_encoder_layers)
cached_len = []
cached_avg = []
cached_key = []
cached_val = []
cached_val2 = []
cached_conv1 = []
cached_conv2 = []
N = batch_size
for i in range(num_encoders):
cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64))
cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i]))
cached_key.append(
torch.zeros(
num_encoder_layers[i], left_context_len[i], N, attention_dims[i]
)
)
cached_val.append(
torch.zeros(
num_encoder_layers[i],
left_context_len[i],
N,
attention_dims[i] // 2,
)
)
cached_val2.append(
torch.zeros(
num_encoder_layers[i],
left_context_len[i],
N,
attention_dims[i] // 2,
)
)
cached_conv1.append(
torch.zeros(
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
)
)
cached_conv2.append(
torch.zeros(
num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1
)
)
self.cached_len = cached_len
self.cached_avg = cached_avg
self.cached_key = cached_key
self.cached_val = cached_val
self.cached_val2 = cached_val2
self.cached_conv1 = cached_conv1
self.cached_conv2 = cached_conv2
self.num_encoders = num_encoders
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_states_input(states: List[torch.Tensor], name: str):
for i, s in enumerate(states):
if isinstance(s, torch.Tensor):
encoder_input[f"{name}_{i}"] = s.numpy()
else:
encoder_input[f"{name}_{i}"] = s
encoder_output.append(f"new_{name}_{i}")
build_states_input(self.cached_len, "cached_len")
build_states_input(self.cached_avg, "cached_avg")
build_states_input(self.cached_key, "cached_key")
build_states_input(self.cached_val, "cached_val")
build_states_input(self.cached_val2, "cached_val2")
build_states_input(self.cached_conv1, "cached_conv1")
build_states_input(self.cached_conv2, "cached_conv2")
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
num_encoders = self.num_encoders
self.cached_len = states[num_encoders * 0 : num_encoders * 1]
self.cached_avg = states[num_encoders * 1 : num_encoders * 2]
self.cached_key = states[num_encoders * 2 : num_encoders * 3]
self.cached_val = states[num_encoders * 3 : num_encoders * 4]
self.cached_val2 = states[num_encoders * 4 : num_encoders * 5]
self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6]
self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7]
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> List[int]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py

View File

@ -1,361 +0,0 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
./pruned_transducer_stateless7_streaming/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py

View File

@ -1,214 +0,0 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import ActivationBalancer, BasicNorm, Whiten
from zipformer import PoolingModule
class PoolingModuleNoProj(nn.Module):
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x = x.cumsum(dim=0) # (T, N, C)
x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
# Cumulated numbers of frames from start
cum_mask = torch.arange(1, x.size(0) + 1, device=x.device)
cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N)
pooling_mask = (1.0 / cum_mask).unsqueeze(2)
# now pooling_mask: (T, N, 1)
x = x * pooling_mask # (T, N, C)
cached_len = cached_len + x.size(0)
cached_avg = x[-1]
return x, cached_len, cached_avg
class PoolingModuleWithProj(nn.Module):
def __init__(self, proj: torch.nn.Module):
super().__init__()
self.proj = proj
self.pooling = PoolingModuleNoProj()
def forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
def streaming_forward(
self,
x: torch.Tensor,
cached_len: torch.Tensor,
cached_avg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (T, N, C)
cached_len:
A tensor of shape (N,)
cached_avg:
A tensor of shape (N, C)
Returns:
Return a tuple containing:
- new_x
- new_cached_len
- new_cached_avg
"""
x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg)
return self.proj(x), cached_len, cached_avg
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj:
assert isinstance(pooling, PoolingModule), type(pooling)
return PoolingModuleWithProj(proj=pooling.proj)
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, (ActivationBalancer, Whiten)):
d[name] = nn.Identity()
elif isinstance(m, PoolingModule) and is_pnnx:
d[name] = convert_pooling_module(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py

View File

@ -1,282 +0,0 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py