mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Support exporting to ONNX format (#501)
* WIP: Support exporting to ONNX format * Minor fixes. * Combine encoder/decoder/joiner into a single file. * Revert merging three onnx models into a single one. It's quite time consuming to extract a sub-graph from the combined model. For instance, it takes more than one hour to extract the encoder model. * Update CI to test ONNX models. * Decode with exported models. * Fix typos. * Add more doc. * Remove ncnn as it is not fully tested yet. * Fix as_strided for streaming conformer.
This commit is contained in:
parent
132132f52a
commit
58a96e5b68
@ -22,8 +22,76 @@ ls -lh $repo/test_wavs/*.wav
|
|||||||
|
|
||||||
pushd $repo/exp
|
pushd $repo/exp
|
||||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||||
|
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
log "Export to torchscript model"
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.onnx
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
log "Decode with ONNX models"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/onnx_check.py \
|
||||||
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.trace()"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||||
|
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||||
|
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.script()"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||||
|
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||||
|
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
|
||||||
for sym in 1 2 3; do
|
for sym in 1 2 3; do
|
||||||
log "Greedy search with --max-sym-per-frame $sym"
|
log "Greedy search with --max-sym-per-frame $sym"
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
|
|||||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
if not torch.jit.is_tracing():
|
||||||
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an PositionalEncoding object."""
|
"""Construct an PositionalEncoding object."""
|
||||||
super(RelPositionalEncoding, self).__init__()
|
super(RelPositionalEncoding, self).__init__()
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
|
||||||
|
# It assumes that the maximum input won't have more than
|
||||||
|
# 10k frames.
|
||||||
|
#
|
||||||
|
# TODO(fangjun): Use torch.jit.script() for this module
|
||||||
|
max_len = 10000
|
||||||
|
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||||
self.pe = None
|
self.pe = None
|
||||||
@ -992,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
|
||||||
time1 means the length of query vector.
|
time1 means the length of query vector.
|
||||||
left_context (int): left context (in frames) used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
@ -1006,20 +1015,32 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
|
|
||||||
time2 = time1 + left_context
|
time2 = time1 + left_context
|
||||||
assert (
|
if not torch.jit.is_tracing():
|
||||||
n == left_context + 2 * time1 - 1
|
assert (
|
||||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
n == left_context + 2 * time1 - 1
|
||||||
|
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
if torch.jit.is_tracing():
|
||||||
batch_stride = x.stride(0)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
head_stride = x.stride(1)
|
cols = torch.arange(time2)
|
||||||
time1_stride = x.stride(2)
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
n_stride = x.stride(3)
|
indexes = rows + cols
|
||||||
return x.as_strided(
|
|
||||||
(batch_size, num_heads, time1, time2),
|
x = x.reshape(-1, n)
|
||||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
x = torch.gather(x, dim=1, index=indexes)
|
||||||
storage_offset=n_stride * (time1 - 1),
|
x = x.reshape(batch_size, num_heads, time1, time2)
|
||||||
)
|
return x
|
||||||
|
else:
|
||||||
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
batch_stride = x.stride(0)
|
||||||
|
head_stride = x.stride(1)
|
||||||
|
time1_stride = x.stride(2)
|
||||||
|
n_stride = x.stride(3)
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, num_heads, time1, time2),
|
||||||
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
|
storage_offset=n_stride * (time1 - 1),
|
||||||
|
)
|
||||||
|
|
||||||
def multi_head_attention_forward(
|
def multi_head_attention_forward(
|
||||||
self,
|
self,
|
||||||
@ -1090,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tgt_len, bsz, embed_dim = query.size()
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
assert embed_dim == embed_dim_to_check
|
if not torch.jit.is_tracing():
|
||||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
assert embed_dim == embed_dim_to_check
|
||||||
|
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||||
|
|
||||||
head_dim = embed_dim // num_heads
|
head_dim = embed_dim // num_heads
|
||||||
assert (
|
if not torch.jit.is_tracing():
|
||||||
head_dim * num_heads == embed_dim
|
assert (
|
||||||
), "embed_dim must be divisible by num_heads"
|
head_dim * num_heads == embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
scaling = float(head_dim) ** -0.5
|
scaling = float(head_dim) ** -0.5
|
||||||
|
|
||||||
@ -1209,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
src_len = k.size(0)
|
src_len = k.size(0)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None and not torch.jit.is_tracing():
|
||||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||||
key_padding_mask.size(0), bsz
|
key_padding_mask.size(0), bsz
|
||||||
)
|
)
|
||||||
@ -1220,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||||
|
|
||||||
pos_emb_bsz = pos_emb.size(0)
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
if not torch.jit.is_tracing():
|
||||||
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
|
|
||||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||||
p = p.permute(0, 2, 3, 1)
|
p = p.permute(0, 2, 3, 1)
|
||||||
@ -1255,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
bsz * num_heads, tgt_len, -1
|
bsz * num_heads, tgt_len, -1
|
||||||
)
|
)
|
||||||
|
|
||||||
assert list(attn_output_weights.size()) == [
|
if not torch.jit.is_tracing():
|
||||||
bsz * num_heads,
|
assert list(attn_output_weights.size()) == [
|
||||||
tgt_len,
|
bsz * num_heads,
|
||||||
src_len,
|
tgt_len,
|
||||||
]
|
src_len,
|
||||||
|
]
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
if attn_mask.dtype == torch.bool:
|
if attn_mask.dtype == torch.bool:
|
||||||
@ -1318,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
attn_output = torch.bmm(attn_output_weights, v)
|
attn_output = torch.bmm(attn_output_weights, v)
|
||||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
|
||||||
|
if not torch.jit.is_tracing():
|
||||||
|
assert list(attn_output.size()) == [
|
||||||
|
bsz * num_heads,
|
||||||
|
tgt_len,
|
||||||
|
head_dim,
|
||||||
|
]
|
||||||
|
|
||||||
attn_output = (
|
attn_output = (
|
||||||
attn_output.transpose(0, 1)
|
attn_output.transpose(0, 1)
|
||||||
.contiguous()
|
.contiguous()
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -77,7 +79,9 @@ class Decoder(nn.Module):
|
|||||||
# It is to support torch script
|
# It is to support torch script
|
||||||
self.conv = nn.Identity()
|
self.conv = nn.Identity()
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
def forward(
|
||||||
|
self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
@ -88,18 +92,24 @@ class Decoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, U, decoder_dim).
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
"""
|
"""
|
||||||
|
if isinstance(need_pad, torch.Tensor):
|
||||||
|
# This is for torch.jit.trace(), which cannot handle the case
|
||||||
|
# when the input argument is not a tensor.
|
||||||
|
need_pad = bool(need_pad)
|
||||||
|
|
||||||
y = y.to(torch.int64)
|
y = y.to(torch.int64)
|
||||||
embedding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad:
|
||||||
embedding_out = F.pad(
|
embedding_out = F.pad(
|
||||||
embedding_out, pad=(self.context_size - 1, 0)
|
embedding_out, pad=(self.context_size - 1, 0)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
assert embedding_out.size(-1) == self.context_size
|
if not torch.jit.is_tracing():
|
||||||
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
embedding_out = F.relu(embedding_out)
|
embedding_out = F.relu(embedding_out)
|
||||||
|
@ -52,10 +52,10 @@ class Joiner(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
|
if not torch.jit.is_tracing():
|
||||||
assert encoder_out.ndim == decoder_out.ndim
|
assert encoder_out.ndim == decoder_out.ndim
|
||||||
assert encoder_out.ndim in (2, 4)
|
assert encoder_out.ndim in (2, 4)
|
||||||
assert encoder_out.shape == decoder_out.shape
|
assert encoder_out.shape == decoder_out.shape
|
||||||
|
|
||||||
if project_input:
|
if project_input:
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||||
|
@ -152,7 +152,8 @@ class BasicNorm(torch.nn.Module):
|
|||||||
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
assert x.shape[self.channel_dim] == self.num_channels
|
if not torch.jit.is_tracing():
|
||||||
|
assert x.shape[self.channel_dim] == self.num_channels
|
||||||
scales = (
|
scales = (
|
||||||
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
||||||
+ self.eps.exp()
|
+ self.eps.exp()
|
||||||
@ -423,7 +424,7 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
self.max_abs = max_abs
|
self.max_abs = max_abs
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
return ActivationBalancerFunction.apply(
|
return ActivationBalancerFunction.apply(
|
||||||
@ -472,7 +473,7 @@ class DoubleSwish(torch.nn.Module):
|
|||||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||||
that we approximate closely with x * sigmoid(x-1).
|
that we approximate closely with x * sigmoid(x-1).
|
||||||
"""
|
"""
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
return x * torch.sigmoid(x - 1.0)
|
return x * torch.sigmoid(x - 1.0)
|
||||||
else:
|
else:
|
||||||
return DoubleSwishFunction.apply(x)
|
return DoubleSwishFunction.apply(x)
|
||||||
|
@ -19,14 +19,67 @@
|
|||||||
# This script converts several saved checkpoints
|
# This script converts several saved checkpoints
|
||||||
# to a single one using model averaging.
|
# to a single one using model averaging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
|
(1) Export to torchscript model using torch.jit.script()
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||||
|
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||||
|
|
||||||
|
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||||
|
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||||
|
|
||||||
|
It will also generate 3 other files: `encoder_jit_script.pt`,
|
||||||
|
`decoder_jit_script.pt`, and `joiner_jit_script.pt`.
|
||||||
|
|
||||||
|
(2) Export to torchscript model using torch.jit.trace()
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
It will generates 3 files: `encoder_jit_trace.pt`,
|
||||||
|
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
|
||||||
|
|
||||||
|
|
||||||
|
(3) Export to ONNX format
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
It will generate the following three files in the given `exp_dir`.
|
||||||
|
Check `onnx_check.py` for how to use them.
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
|
||||||
|
|
||||||
|
(4) Export `model.state_dict()`
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
It will generate a file exp_dir/pretrained.pt
|
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||||
|
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
|
||||||
To use the generated file with `pruned_transducer_stateless3/decode.py`,
|
To use the generated file with `pruned_transducer_stateless3/decode.py`,
|
||||||
you can do:
|
you can do:
|
||||||
@ -42,6 +95,20 @@ you can do:
|
|||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search \
|
--decoding-method greedy_search \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--bpe-model data/lang_bpe_500/bpe.model
|
||||||
|
|
||||||
|
Check ./pretrained.py for its usage.
|
||||||
|
|
||||||
|
Note: If you don't want to train a model from scratch, we have
|
||||||
|
provided one for you. You can get it at
|
||||||
|
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||||
|
|
||||||
|
with the following commands:
|
||||||
|
|
||||||
|
sudo apt-get install git-lfs
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||||
|
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -50,6 +117,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -114,6 +183,42 @@ def get_parser():
|
|||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""True to save a model after applying torch.jit.script.
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
It will generate 4 files:
|
||||||
|
- encoder_jit_script.pt
|
||||||
|
- decoder_jit_script.pt
|
||||||
|
- joiner_jit_script.pt
|
||||||
|
- cpu_jit.pt (which combines the above 3 files)
|
||||||
|
|
||||||
|
Check ./jit_pretrained.py for how to use them.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-trace",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to save a model after applying torch.jit.trace.
|
||||||
|
It will generate 3 files:
|
||||||
|
- encoder_jit_trace.pt
|
||||||
|
- decoder_jit_trace.pt
|
||||||
|
- joiner_jit_trace.pt
|
||||||
|
|
||||||
|
Check ./jit_pretrained.py for how to use them.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""If True, --jit is ignored and it exports the model
|
||||||
|
to onnx format. Three files will be generated:
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
|
||||||
|
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -139,6 +244,275 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_script(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.script()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(encoder_model)
|
||||||
|
script_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_script(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.script()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(decoder_model)
|
||||||
|
script_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_script(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(joiner_model)
|
||||||
|
script_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T, C)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
# encoder_model = torch.jit.script(encoder_model)
|
||||||
|
# It throws the following error for the above statement
|
||||||
|
#
|
||||||
|
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||||
|
# 11 is not supported. Please feel free to request support or
|
||||||
|
# submit a pull request on PyTorch GitHub.
|
||||||
|
#
|
||||||
|
# I cannot find which statement causes the above error.
|
||||||
|
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||||
|
# works well for the current reworked model
|
||||||
|
warmup = 1.0
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, x_lens, warmup),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "x_lens", "warmup"],
|
||||||
|
output_names=["encoder_out", "encoder_out_lens"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||||
|
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||||
|
# in this case
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
(y, need_pad),
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y", "need_pad"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- joiner_out: a tensor of shape (N, vocab_size)
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
project_input = True
|
||||||
|
# Note: It uses torch.jit.trace() internally
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(encoder_out, decoder_out, project_input),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["encoder_out", "decoder_out", "project_input"],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
@ -165,7 +539,7 @@ def main():
|
|||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@ -185,7 +559,9 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
elif params.avg == 1:
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
@ -196,14 +572,39 @@ def main():
|
|||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
model.eval()
|
)
|
||||||
|
|
||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
if params.jit:
|
if params.onnx is True:
|
||||||
|
opset_version = 11
|
||||||
|
logging.info("Exporting to onnx format")
|
||||||
|
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
model.encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
model.decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
model.joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
elif params.jit is True:
|
||||||
|
logging.info("Using torch.jit.script()")
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
@ -214,8 +615,29 @@ def main():
|
|||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
model.save(str(filename))
|
model.save(str(filename))
|
||||||
logging.info(f"Saved to {filename}")
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
# Also export encoder/decoder/joiner separately
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
elif params.jit_trace is True:
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
else:
|
else:
|
||||||
logging.info("Not using torch.jit.script")
|
logging.info("Not using torchscript")
|
||||||
# Save it using a format so that it can be loaded
|
# Save it using a format so that it can be loaded
|
||||||
# by :func:`load_checkpoint`
|
# by :func:`load_checkpoint`
|
||||||
filename = params.exp_dir / "pretrained.pt"
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
@ -0,0 +1,338 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads torchscript models, either exported by `torch.jit.trace()`
|
||||||
|
or by `torch.jit.script()`, and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Context size of the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. "
|
||||||
|
f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
decoder: torch.jit.ScriptModule,
|
||||||
|
joiner: torch.jit.ScriptModule,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
context_size: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
decoder:
|
||||||
|
The decoder model.
|
||||||
|
joiner:
|
||||||
|
The joiner model.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,).
|
||||||
|
context_size:
|
||||||
|
The context size of the decoder model.
|
||||||
|
Returns:
|
||||||
|
Return the decoded results for each utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = encoder_out.device
|
||||||
|
blank_id = 0 # hard-code to 0
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (N, context_size)
|
||||||
|
|
||||||
|
decoder_out = decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=torch.tensor([False]),
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = packed_encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out
|
||||||
|
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
|
logits = joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
)
|
||||||
|
# logits'shape (batch_size, vocab_size)
|
||||||
|
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id:
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=torch.tensor([False]),
|
||||||
|
)
|
||||||
|
decoder_out = decoder_out.squeeze(1)
|
||||||
|
|
||||||
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
encoder = torch.jit.load(args.encoder_model_filename)
|
||||||
|
decoder = torch.jit.load(args.decoder_model_filename)
|
||||||
|
joiner = torch.jit.load(args.joiner_model_filename)
|
||||||
|
|
||||||
|
encoder.eval()
|
||||||
|
decoder.eval()
|
||||||
|
joiner.eval()
|
||||||
|
|
||||||
|
encoder.to(device)
|
||||||
|
decoder.to(device)
|
||||||
|
joiner.to(device)
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model)
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = args.sample_rate
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=args.sound_files,
|
||||||
|
expected_sample_rate=args.sample_rate,
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(
|
||||||
|
features,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=math.log(1e-10),
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = encoder(
|
||||||
|
x=features,
|
||||||
|
x_lens=feature_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = greedy_search(
|
||||||
|
decoder=decoder,
|
||||||
|
joiner=joiner,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
context_size=args.context_size,
|
||||||
|
)
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(args.sound_files, hyps):
|
||||||
|
words = sp.decode(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
199
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
199
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
@ -0,0 +1,199 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script checks that exported onnx models produce the same output
|
||||||
|
with the given torchscript model for the same input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ort.set_default_logger_severity(3)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the torchscript model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-encoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the onnx encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-decoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the onnx decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-joiner-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the onnx joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder(
|
||||||
|
model: torch.jit.ScriptModule,
|
||||||
|
encoder_session: ort.InferenceSession,
|
||||||
|
):
|
||||||
|
encoder_inputs = encoder_session.get_inputs()
|
||||||
|
assert encoder_inputs[0].name == "x"
|
||||||
|
assert encoder_inputs[1].name == "x_lens"
|
||||||
|
assert encoder_inputs[0].shape == ["N", "T", 80]
|
||||||
|
assert encoder_inputs[1].shape == ["N"]
|
||||||
|
|
||||||
|
for N in [1, 5]:
|
||||||
|
for T in [12, 25]:
|
||||||
|
print("N, T", N, T)
|
||||||
|
x = torch.rand(N, T, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
||||||
|
x_lens[0] = T
|
||||||
|
|
||||||
|
encoder_inputs = {
|
||||||
|
"x": x.numpy(),
|
||||||
|
"x_lens": x_lens.numpy(),
|
||||||
|
}
|
||||||
|
encoder_out, encoder_out_lens = encoder_session.run(
|
||||||
|
["encoder_out", "encoder_out_lens"],
|
||||||
|
encoder_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||||
|
|
||||||
|
encoder_out = torch.from_numpy(encoder_out)
|
||||||
|
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||||
|
(encoder_out - torch_encoder_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decoder(
|
||||||
|
model: torch.jit.ScriptModule,
|
||||||
|
decoder_session: ort.InferenceSession,
|
||||||
|
):
|
||||||
|
decoder_inputs = decoder_session.get_inputs()
|
||||||
|
assert decoder_inputs[0].name == "y"
|
||||||
|
assert decoder_inputs[0].shape == ["N", 2]
|
||||||
|
for N in [1, 5, 10]:
|
||||||
|
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||||
|
|
||||||
|
decoder_inputs = {"y": y.numpy()}
|
||||||
|
decoder_out = decoder_session.run(
|
||||||
|
["decoder_out"],
|
||||||
|
decoder_inputs,
|
||||||
|
)[0]
|
||||||
|
decoder_out = torch.from_numpy(decoder_out)
|
||||||
|
|
||||||
|
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||||
|
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
||||||
|
(decoder_out - torch_decoder_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_joiner(
|
||||||
|
model: torch.jit.ScriptModule,
|
||||||
|
joiner_session: ort.InferenceSession,
|
||||||
|
):
|
||||||
|
joiner_inputs = joiner_session.get_inputs()
|
||||||
|
assert joiner_inputs[0].name == "encoder_out"
|
||||||
|
assert joiner_inputs[0].shape == ["N", 512]
|
||||||
|
|
||||||
|
assert joiner_inputs[1].name == "decoder_out"
|
||||||
|
assert joiner_inputs[1].shape == ["N", 512]
|
||||||
|
|
||||||
|
for N in [1, 5, 10]:
|
||||||
|
encoder_out = torch.rand(N, 512)
|
||||||
|
decoder_out = torch.rand(N, 512)
|
||||||
|
|
||||||
|
joiner_inputs = {
|
||||||
|
"encoder_out": encoder_out.numpy(),
|
||||||
|
"decoder_out": decoder_out.numpy(),
|
||||||
|
}
|
||||||
|
joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
|
||||||
|
joiner_out = torch.from_numpy(joiner_out)
|
||||||
|
|
||||||
|
torch_joiner_out = model.joiner(
|
||||||
|
encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=True,
|
||||||
|
)
|
||||||
|
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
||||||
|
(joiner_out - torch_joiner_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = torch.jit.load(args.jit_filename)
|
||||||
|
|
||||||
|
options = ort.SessionOptions()
|
||||||
|
options.inter_op_num_threads = 1
|
||||||
|
options.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
logging.info("Test encoder")
|
||||||
|
encoder_session = ort.InferenceSession(
|
||||||
|
args.onnx_encoder_filename,
|
||||||
|
sess_options=options,
|
||||||
|
)
|
||||||
|
test_encoder(model, encoder_session)
|
||||||
|
|
||||||
|
logging.info("Test decoder")
|
||||||
|
decoder_session = ort.InferenceSession(
|
||||||
|
args.onnx_decoder_filename,
|
||||||
|
sess_options=options,
|
||||||
|
)
|
||||||
|
test_decoder(model, decoder_session)
|
||||||
|
|
||||||
|
logging.info("Test joiner")
|
||||||
|
joiner_session = ort.InferenceSession(
|
||||||
|
args.onnx_joiner_filename,
|
||||||
|
sess_options=options,
|
||||||
|
)
|
||||||
|
test_joiner(model, joiner_session)
|
||||||
|
logging.info("Finished checking ONNX models")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220727)
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
@ -0,0 +1,337 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX models and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_trace_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Context size of the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. "
|
||||||
|
f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
decoder: ort.InferenceSession,
|
||||||
|
joiner: ort.InferenceSession,
|
||||||
|
encoder_out: np.ndarray,
|
||||||
|
encoder_out_lens: np.ndarray,
|
||||||
|
context_size: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
decoder:
|
||||||
|
The decoder model.
|
||||||
|
joiner:
|
||||||
|
The joiner model.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,).
|
||||||
|
context_size:
|
||||||
|
The context size of the decoder model.
|
||||||
|
Returns:
|
||||||
|
Return the decoded results for each utterance.
|
||||||
|
"""
|
||||||
|
encoder_out = torch.from_numpy(encoder_out)
|
||||||
|
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = 0 # hard-code to 0
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
|
decoder_input_nodes = decoder.get_inputs()
|
||||||
|
decoder_output_nodes = decoder.get_outputs()
|
||||||
|
|
||||||
|
joiner_input_nodes = joiner.get_inputs()
|
||||||
|
joiner_output_nodes = joiner.get_outputs()
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (N, context_size)
|
||||||
|
|
||||||
|
decoder_out = decoder.run(
|
||||||
|
[decoder_output_nodes[0].name],
|
||||||
|
{
|
||||||
|
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||||
|
},
|
||||||
|
)[0].squeeze(1)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = packed_encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out
|
||||||
|
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
|
logits = joiner.run(
|
||||||
|
[joiner_output_nodes[0].name],
|
||||||
|
{
|
||||||
|
joiner_input_nodes[0].name: current_encoder_out.numpy(),
|
||||||
|
joiner_input_nodes[1].name: decoder_out,
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
logits = torch.from_numpy(logits)
|
||||||
|
# logits'shape (batch_size, vocab_size)
|
||||||
|
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id:
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = decoder.run(
|
||||||
|
[decoder_output_nodes[0].name],
|
||||||
|
{
|
||||||
|
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||||
|
},
|
||||||
|
)[0].squeeze(1)
|
||||||
|
|
||||||
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
encoder = ort.InferenceSession(
|
||||||
|
args.encoder_model_filename,
|
||||||
|
sess_options=session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = ort.InferenceSession(
|
||||||
|
args.decoder_model_filename,
|
||||||
|
sess_options=session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = ort.InferenceSession(
|
||||||
|
args.joiner_model_filename,
|
||||||
|
sess_options=session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model)
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = args.sample_rate
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=args.sound_files,
|
||||||
|
expected_sample_rate=args.sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(
|
||||||
|
features,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=math.log(1e-10),
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_input_nodes = encoder.get_inputs()
|
||||||
|
encoder_out_nodes = encoder.get_outputs()
|
||||||
|
encoder_out, encoder_out_lens = encoder.run(
|
||||||
|
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
||||||
|
{
|
||||||
|
encoder_input_nodes[0].name: features.numpy(),
|
||||||
|
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = greedy_search(
|
||||||
|
decoder=decoder,
|
||||||
|
joiner=joiner,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
context_size=args.context_size,
|
||||||
|
)
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(args.sound_files, hyps):
|
||||||
|
words = sp.decode(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -15,7 +15,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Usage:
|
This script loads a checkpoint and uses it to decode waves.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
|
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
|
||||||
|
and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
|
||||||
|
and `nn.Conv2d`.
|
||||||
|
|
||||||
|
The scaled version are required only in the training time. It simplifies our
|
||||||
|
life by converting them their non-scaled version during inference time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
def _get_weight(self: torch.nn.Linear):
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bias(self: torch.nn.Linear):
|
||||||
|
return self.bias
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
|
||||||
|
"""Convert an instance of ScaledLinear to nn.Linear.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scaled_linear:
|
||||||
|
The layer to be converted.
|
||||||
|
Returns:
|
||||||
|
Return a linear layer. It satisfies:
|
||||||
|
|
||||||
|
scaled_linear(x) == linear(x)
|
||||||
|
|
||||||
|
for any given input tensor `x`.
|
||||||
|
"""
|
||||||
|
assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
|
||||||
|
|
||||||
|
# if not hasattr(torch.nn.Linear, "get_weight"):
|
||||||
|
# torch.nn.Linear.get_weight = _get_weight
|
||||||
|
# torch.nn.Linear.get_bias = _get_bias
|
||||||
|
|
||||||
|
weight = scaled_linear.get_weight()
|
||||||
|
bias = scaled_linear.get_bias()
|
||||||
|
has_bias = bias is not None
|
||||||
|
|
||||||
|
linear = torch.nn.Linear(
|
||||||
|
in_features=scaled_linear.in_features,
|
||||||
|
out_features=scaled_linear.out_features,
|
||||||
|
bias=True, # otherwise, it throws errors when converting to PNNX format.
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
|
linear.weight.data.copy_(weight)
|
||||||
|
|
||||||
|
if has_bias:
|
||||||
|
linear.bias.data.copy_(bias)
|
||||||
|
else:
|
||||||
|
linear.bias.data.zero_()
|
||||||
|
|
||||||
|
return linear
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
|
||||||
|
"""Convert an instance of ScaledConv1d to nn.Conv1d.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scaled_conv1d:
|
||||||
|
The layer to be converted.
|
||||||
|
Returns:
|
||||||
|
Return an instance of nn.Conv1d that has the same `forward()` behavior
|
||||||
|
of the given `scaled_conv1d`.
|
||||||
|
"""
|
||||||
|
assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)
|
||||||
|
|
||||||
|
weight = scaled_conv1d.get_weight()
|
||||||
|
bias = scaled_conv1d.get_bias()
|
||||||
|
has_bias = bias is not None
|
||||||
|
|
||||||
|
conv1d = nn.Conv1d(
|
||||||
|
in_channels=scaled_conv1d.in_channels,
|
||||||
|
out_channels=scaled_conv1d.out_channels,
|
||||||
|
kernel_size=scaled_conv1d.kernel_size,
|
||||||
|
stride=scaled_conv1d.stride,
|
||||||
|
padding=scaled_conv1d.padding,
|
||||||
|
dilation=scaled_conv1d.dilation,
|
||||||
|
groups=scaled_conv1d.groups,
|
||||||
|
bias=scaled_conv1d.bias is not None,
|
||||||
|
padding_mode=scaled_conv1d.padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv1d.weight.data.copy_(weight)
|
||||||
|
if has_bias:
|
||||||
|
conv1d.bias.data.copy_(bias)
|
||||||
|
|
||||||
|
return conv1d
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
|
||||||
|
"""Convert an instance of ScaledConv2d to nn.Conv2d.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scaled_conv2d:
|
||||||
|
The layer to be converted.
|
||||||
|
Returns:
|
||||||
|
Return an instance of nn.Conv2d that has the same `forward()` behavior
|
||||||
|
of the given `scaled_conv2d`.
|
||||||
|
"""
|
||||||
|
assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)
|
||||||
|
|
||||||
|
weight = scaled_conv2d.get_weight()
|
||||||
|
bias = scaled_conv2d.get_bias()
|
||||||
|
has_bias = bias is not None
|
||||||
|
|
||||||
|
conv2d = nn.Conv2d(
|
||||||
|
in_channels=scaled_conv2d.in_channels,
|
||||||
|
out_channels=scaled_conv2d.out_channels,
|
||||||
|
kernel_size=scaled_conv2d.kernel_size,
|
||||||
|
stride=scaled_conv2d.stride,
|
||||||
|
padding=scaled_conv2d.padding,
|
||||||
|
dilation=scaled_conv2d.dilation,
|
||||||
|
groups=scaled_conv2d.groups,
|
||||||
|
bias=scaled_conv2d.bias is not None,
|
||||||
|
padding_mode=scaled_conv2d.padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv2d.weight.data.copy_(weight)
|
||||||
|
if has_bias:
|
||||||
|
conv2d.bias.data.copy_(bias)
|
||||||
|
|
||||||
|
return conv2d
|
||||||
|
|
||||||
|
|
||||||
|
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
|
||||||
|
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
|
||||||
|
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
|
||||||
|
and `nn.Conv2d`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The model to be converted.
|
||||||
|
inplace:
|
||||||
|
If True, the input model is modified inplace.
|
||||||
|
If False, the input model is copied and we modify the copied version.
|
||||||
|
Return:
|
||||||
|
Return a model without scaled layers.
|
||||||
|
"""
|
||||||
|
if not inplace:
|
||||||
|
model = copy.deepcopy(model)
|
||||||
|
|
||||||
|
excluded_patterns = r"self_attn\.(in|out)_proj"
|
||||||
|
p = re.compile(excluded_patterns)
|
||||||
|
|
||||||
|
d = {}
|
||||||
|
for name, m in model.named_modules():
|
||||||
|
if isinstance(m, ScaledLinear):
|
||||||
|
if p.search(name) is not None:
|
||||||
|
continue
|
||||||
|
d[name] = scaled_linear_to_linear(m)
|
||||||
|
elif isinstance(m, ScaledConv1d):
|
||||||
|
d[name] = scaled_conv1d_to_conv1d(m)
|
||||||
|
elif isinstance(m, ScaledConv2d):
|
||||||
|
d[name] = scaled_conv2d_to_conv2d(m)
|
||||||
|
|
||||||
|
for k, v in d.items():
|
||||||
|
if "." in k:
|
||||||
|
parent, child = k.rsplit(".", maxsplit=1)
|
||||||
|
setattr(model.get_submodule(parent), child, v)
|
||||||
|
else:
|
||||||
|
setattr(model, k, v)
|
||||||
|
|
||||||
|
return model
|
@ -0,0 +1,201 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless3/test_scaling_converter.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
|
||||||
|
from scaling_converter import (
|
||||||
|
convert_scaled_to_non_scaled,
|
||||||
|
scaled_conv1d_to_conv1d,
|
||||||
|
scaled_conv2d_to_conv2d,
|
||||||
|
scaled_linear_to_linear,
|
||||||
|
)
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
params.dynamic_chunk_training = False
|
||||||
|
params.short_chunk_size = 25
|
||||||
|
params.num_left_chunks = 4
|
||||||
|
params.causal_convolution = False
|
||||||
|
|
||||||
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_linear_to_linear():
|
||||||
|
N = 5
|
||||||
|
in_features = 10
|
||||||
|
out_features = 20
|
||||||
|
for bias in [True, False]:
|
||||||
|
scaled_linear = ScaledLinear(
|
||||||
|
in_features=in_features,
|
||||||
|
out_features=out_features,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
linear = scaled_linear_to_linear(scaled_linear)
|
||||||
|
x = torch.rand(N, in_features)
|
||||||
|
|
||||||
|
y1 = scaled_linear(x)
|
||||||
|
y2 = linear(x)
|
||||||
|
assert torch.allclose(y1, y2)
|
||||||
|
|
||||||
|
jit_scaled_linear = torch.jit.script(scaled_linear)
|
||||||
|
jit_linear = torch.jit.script(linear)
|
||||||
|
|
||||||
|
y3 = jit_scaled_linear(x)
|
||||||
|
y4 = jit_linear(x)
|
||||||
|
|
||||||
|
assert torch.allclose(y3, y4)
|
||||||
|
assert torch.allclose(y1, y4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv1d_to_conv1d():
|
||||||
|
in_channels = 3
|
||||||
|
for bias in [True, False]:
|
||||||
|
scaled_conv1d = ScaledConv1d(
|
||||||
|
in_channels,
|
||||||
|
6,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv1d = scaled_conv1d_to_conv1d(scaled_conv1d)
|
||||||
|
|
||||||
|
x = torch.rand(20, in_channels, 10)
|
||||||
|
y1 = scaled_conv1d(x)
|
||||||
|
y2 = conv1d(x)
|
||||||
|
assert torch.allclose(y1, y2)
|
||||||
|
|
||||||
|
jit_scaled_conv1d = torch.jit.script(scaled_conv1d)
|
||||||
|
jit_conv1d = torch.jit.script(conv1d)
|
||||||
|
|
||||||
|
y3 = jit_scaled_conv1d(x)
|
||||||
|
y4 = jit_conv1d(x)
|
||||||
|
|
||||||
|
assert torch.allclose(y3, y4)
|
||||||
|
assert torch.allclose(y1, y4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv2d_to_conv2d():
|
||||||
|
in_channels = 1
|
||||||
|
for bias in [True, False]:
|
||||||
|
scaled_conv2d = ScaledConv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=3,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv2d = scaled_conv2d_to_conv2d(scaled_conv2d)
|
||||||
|
|
||||||
|
x = torch.rand(20, in_channels, 10, 20)
|
||||||
|
y1 = scaled_conv2d(x)
|
||||||
|
y2 = conv2d(x)
|
||||||
|
assert torch.allclose(y1, y2)
|
||||||
|
|
||||||
|
jit_scaled_conv2d = torch.jit.script(scaled_conv2d)
|
||||||
|
jit_conv2d = torch.jit.script(conv2d)
|
||||||
|
|
||||||
|
y3 = jit_scaled_conv2d(x)
|
||||||
|
y4 = jit_conv2d(x)
|
||||||
|
|
||||||
|
assert torch.allclose(y3, y4)
|
||||||
|
assert torch.allclose(y1, y4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_scaled_to_non_scaled():
|
||||||
|
for inplace in [False, True]:
|
||||||
|
model = get_model()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
orig_model = copy.deepcopy(model)
|
||||||
|
|
||||||
|
converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)
|
||||||
|
|
||||||
|
model = orig_model
|
||||||
|
|
||||||
|
# test encoder
|
||||||
|
N = 2
|
||||||
|
T = 100
|
||||||
|
vocab_size = model.decoder.vocab_size
|
||||||
|
|
||||||
|
x = torch.randn(N, T, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.full((N,), x.size(1))
|
||||||
|
|
||||||
|
e1, e1_lens = model.encoder(x, x_lens)
|
||||||
|
e2, e2_lens = converted_model.encoder(x, x_lens)
|
||||||
|
|
||||||
|
assert torch.all(torch.eq(e1_lens, e2_lens))
|
||||||
|
assert torch.allclose(e1, e2), (e1 - e2).abs().max()
|
||||||
|
|
||||||
|
# test decoder
|
||||||
|
U = 50
|
||||||
|
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
|
||||||
|
|
||||||
|
d1 = model.decoder(y)
|
||||||
|
d2 = model.decoder(y)
|
||||||
|
|
||||||
|
assert torch.allclose(d1, d2)
|
||||||
|
|
||||||
|
# test simple projection
|
||||||
|
lm1 = model.simple_lm_proj(d1)
|
||||||
|
am1 = model.simple_am_proj(e1)
|
||||||
|
|
||||||
|
lm2 = converted_model.simple_lm_proj(d2)
|
||||||
|
am2 = converted_model.simple_am_proj(e2)
|
||||||
|
|
||||||
|
assert torch.allclose(lm1, lm2)
|
||||||
|
assert torch.allclose(am1, am2)
|
||||||
|
|
||||||
|
# test joiner
|
||||||
|
e = torch.rand(2, 3, 4, 512)
|
||||||
|
d = torch.rand(2, 3, 4, 512)
|
||||||
|
|
||||||
|
j1 = model.joiner(e, d)
|
||||||
|
j2 = converted_model.joiner(e, d)
|
||||||
|
assert torch.allclose(j1, j2)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
test_scaled_linear_to_linear()
|
||||||
|
test_scaled_conv1d_to_conv1d()
|
||||||
|
test_scaled_conv2d_to_conv2d()
|
||||||
|
test_convert_scaled_to_non_scaled()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220730)
|
||||||
|
main()
|
@ -436,13 +436,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
def get_transducer_model(
|
||||||
|
params: AttributeDict,
|
||||||
|
enable_giga: bool = True,
|
||||||
|
) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|
||||||
decoder_giga = get_decoder_model(params)
|
if enable_giga:
|
||||||
joiner_giga = get_joiner_model(params)
|
logging.info("Use giga")
|
||||||
|
decoder_giga = get_decoder_model(params)
|
||||||
|
joiner_giga = get_joiner_model(params)
|
||||||
|
else:
|
||||||
|
logging.info("Disable giga")
|
||||||
|
decoder_giga = None
|
||||||
|
joiner_giga = None
|
||||||
|
|
||||||
model = Transducer(
|
model = Transducer(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
|
@ -20,3 +20,6 @@ sentencepiece==0.1.96
|
|||||||
tensorboard==2.8.0
|
tensorboard==2.8.0
|
||||||
typeguard==2.13.3
|
typeguard==2.13.3
|
||||||
multi_quantization
|
multi_quantization
|
||||||
|
|
||||||
|
onnx
|
||||||
|
onnxruntime
|
||||||
|
@ -4,3 +4,5 @@ sentencepiece>=0.1.96
|
|||||||
tensorboard
|
tensorboard
|
||||||
typeguard
|
typeguard
|
||||||
multi_quantization
|
multi_quantization
|
||||||
|
onnx
|
||||||
|
onnxruntime
|
||||||
|
Loading…
x
Reference in New Issue
Block a user