mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Decode with exported models.
This commit is contained in:
parent
71ea196370
commit
f572e149a9
@ -42,15 +42,56 @@ log "Export to torchscript model"
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
ls -lh $repo/exp/*.onnx
|
ls -lh $repo/exp/*.onnx
|
||||||
ls -lh $repo/exp/*.pt
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
log "Decode with ONNX models"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/onnx_check.py \
|
./pruned_transducer_stateless3/onnx_check.py \
|
||||||
--jit-filename $repo/exp/cpu_jit.pt \
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.trace()"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||||
|
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||||
|
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.script()"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||||
|
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||||
|
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
|
||||||
for sym in 1 2 3; do
|
for sym in 1 2 3; do
|
||||||
log "Greedy search with --max-sym-per-frame $sym"
|
log "Greedy search with --max-sym-per-frame $sym"
|
||||||
|
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -77,7 +79,9 @@ class Decoder(nn.Module):
|
|||||||
# It is to support torch script
|
# It is to support torch script
|
||||||
self.conv = nn.Identity()
|
self.conv = nn.Identity()
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
def forward(
|
||||||
|
self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
@ -88,17 +92,23 @@ class Decoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, U, decoder_dim).
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
"""
|
"""
|
||||||
|
if isinstance(need_pad, torch.Tensor):
|
||||||
|
# This is for torch.jit.trace(), which cannot handle the case
|
||||||
|
# when the input argument is not a tensor.
|
||||||
|
need_pad = bool(need_pad)
|
||||||
|
|
||||||
y = y.to(torch.int64)
|
y = y.to(torch.int64)
|
||||||
embedding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad:
|
||||||
embedding_out = F.pad(
|
embedding_out = F.pad(
|
||||||
embedding_out, pad=(self.context_size - 1, 0)
|
embedding_out, pad=(self.context_size - 1, 0)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
|
if not torch.jit.is_tracing():
|
||||||
assert embedding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
@ -52,7 +52,7 @@ class Joiner(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
|
if not torch.jit.is_tracing():
|
||||||
assert encoder_out.ndim == decoder_out.ndim
|
assert encoder_out.ndim == decoder_out.ndim
|
||||||
assert encoder_out.ndim in (2, 4)
|
assert encoder_out.ndim in (2, 4)
|
||||||
assert encoder_out.shape == decoder_out.shape
|
assert encoder_out.shape == decoder_out.shape
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
(1) Export to torchscript model
|
(1) Export to torchscript model using torch.jit.script()
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
@ -36,7 +36,23 @@ load it by `torch.jit.load("cpu_jit.pt")`.
|
|||||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||||
|
|
||||||
(2) Export to ONNX format
|
It will also generate 3 other files: `encoder_jit_script.pt`,
|
||||||
|
`decoder_jit_script.pt`, and `joiner_jit_script.pt`.
|
||||||
|
|
||||||
|
(2) Export to torchscript model using torch.jit.trace()
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
It will generate 3 files: `encoder_jit_trace.pt`,
|
||||||
|
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
|
||||||
|
|
||||||
|
|
||||||
|
(3) Export to ONNX format
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
@ -53,7 +69,7 @@ Check `onnx_check.py` for how to use them.
|
|||||||
- joiner.onnx
|
- joiner.onnx
|
||||||
|
|
||||||
|
|
||||||
(3) Export `model.state_dict()`
|
(4) Export `model.state_dict()`
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
@ -78,6 +94,8 @@ you can do:
|
|||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search \
|
--decoding-method greedy_search \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--bpe-model data/lang_bpe_500/bpe.model
|
||||||
|
|
||||||
|
Check ./pretrained.py for its usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -87,6 +105,7 @@ from pathlib import Path
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -154,6 +173,14 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-trace",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to save a model after applying torch.jit.trace.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--onnx",
|
"--onnx",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -189,6 +216,128 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_script(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.script()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(encoder_model)
|
||||||
|
script_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_script(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.script()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(decoder_model)
|
||||||
|
script_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_script(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.script()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
script_model = torch.jit.script(joiner_model)
|
||||||
|
script_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_encoder_model_onnx(
|
def export_encoder_model_onnx(
|
||||||
encoder_model: nn.Module,
|
encoder_model: nn.Module,
|
||||||
encoder_filename: str,
|
encoder_filename: str,
|
||||||
@ -262,6 +411,8 @@ def export_decoder_model_onnx(
|
|||||||
|
|
||||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
decoder_model:
|
decoder_model:
|
||||||
The decoder model to be exported.
|
The decoder model to be exported.
|
||||||
@ -399,6 +550,7 @@ def main():
|
|||||||
|
|
||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
if params.onnx is True:
|
if params.onnx is True:
|
||||||
opset_version = 11
|
opset_version = 11
|
||||||
@ -424,6 +576,7 @@ def main():
|
|||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
)
|
)
|
||||||
elif params.jit is True:
|
elif params.jit is True:
|
||||||
|
logging.info("Using torch.jit.script()")
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
@ -434,8 +587,29 @@ def main():
|
|||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
model.save(str(filename))
|
model.save(str(filename))
|
||||||
logging.info(f"Saved to {filename}")
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
# Also export encoder/decoder/joiner separately
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
elif params.jit_trace is True:
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
else:
|
else:
|
||||||
logging.info("Not using torch.jit.script")
|
logging.info("Not using torchscript")
|
||||||
# Save it using a format so that it can be loaded
|
# Save it using a format so that it can be loaded
|
||||||
# by :func:`load_checkpoint`
|
# by :func:`load_checkpoint`
|
||||||
filename = params.exp_dir / "pretrained.pt"
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
@ -0,0 +1,338 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads torchscript models, either exported by `torch.jit.trace()`
|
||||||
|
or by `torch.jit.script()`, and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Context size of the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list of 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. "
|
||||||
|
f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
decoder: torch.jit.ScriptModule,
|
||||||
|
joiner: torch.jit.ScriptModule,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
context_size: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
decoder:
|
||||||
|
The decoder model.
|
||||||
|
joiner:
|
||||||
|
The joiner model.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,).
|
||||||
|
context_size:
|
||||||
|
The context size of the decoder model.
|
||||||
|
Returns:
|
||||||
|
Return the decoded results for each utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = encoder_out.device
|
||||||
|
blank_id = 0 # hard-code to 0
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (N, context_size)
|
||||||
|
|
||||||
|
decoder_out = decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=torch.tensor([False]),
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = packed_encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out
|
||||||
|
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
|
logits = joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
)
|
||||||
|
# logits' shape: (batch_size, vocab_size)
|
||||||
|
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id:
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = decoder(
|
||||||
|
decoder_input,
|
||||||
|
need_pad=torch.tensor([False]),
|
||||||
|
)
|
||||||
|
decoder_out = decoder_out.squeeze(1)
|
||||||
|
|
||||||
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
encoder = torch.jit.load(args.encoder_model_filename)
|
||||||
|
decoder = torch.jit.load(args.decoder_model_filename)
|
||||||
|
joiner = torch.jit.load(args.joiner_model_filename)
|
||||||
|
|
||||||
|
encoder.eval()
|
||||||
|
decoder.eval()
|
||||||
|
joiner.eval()
|
||||||
|
|
||||||
|
encoder.to(device)
|
||||||
|
decoder.to(device)
|
||||||
|
joiner.to(device)
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model)
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = args.sample_rate
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=args.sound_files,
|
||||||
|
expected_sample_rate=args.sample_rate,
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(
|
||||||
|
features,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=math.log(1e-10),
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = encoder(
|
||||||
|
x=features,
|
||||||
|
x_lens=feature_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = greedy_search(
|
||||||
|
decoder=decoder,
|
||||||
|
joiner=joiner,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
context_size=args.context_size,
|
||||||
|
)
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(args.sound_files, hyps):
|
||||||
|
words = sp.decode(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
161
egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py
Executable file
161
egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py
Executable file
@ -0,0 +1,161 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ncnn models and uses them to decode waves.
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/ncnn_pretrained.py \
|
||||||
|
--model-dir /path/to/ncnn/model_dir \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
We assume the following files exist in the given `model_dir`:
|
||||||
|
|
||||||
|
- encoder_jit_trace.ncnn.param
|
||||||
|
- encoder_jit_trace.ncnn.bin
|
||||||
|
- decoder_jit_trace.ncnn.param
|
||||||
|
- decoder_jit_trace.ncnn.bin
|
||||||
|
- joiner_jit_trace.ncnn.param
|
||||||
|
- joiner_jit_trace.ncnn.bin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import ncnn
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the ncnn models directory. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Context size of the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load every given sound file and keep only its first channel.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The sample rate every file must have. An assertion fails for the
        first file whose sample rate differs.
    Returns:
      Return a list with one entry per file: the first channel of whatever
      `torchaudio.load()` returned for it.
    """

    def _first_channel(path: str) -> torch.Tensor:
        samples, rate = torchaudio.load(path)
        assert rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {rate}"
        )
        # We use only the first channel
        return samples[0]

    return [_first_channel(path) for path in filenames]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """Check that the ncnn model files exist and load the joiner parameters.

    The six `*_jit_trace.ncnn.{param,bin}` files are looked up in
    `--model-dir`; an assertion fails if any of them is missing.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    model_dir = Path(args.model_dir)

    # Keyed by file name; insertion order is the order in which the files
    # are checked below.
    model_files = {
        name: model_dir / name
        for name in (
            "encoder_jit_trace.ncnn.param",
            "encoder_jit_trace.ncnn.bin",
            "decoder_jit_trace.ncnn.param",
            "decoder_jit_trace.ncnn.bin",
            "joiner_jit_trace.ncnn.param",
            "joiner_jit_trace.ncnn.bin",
        )
    }
    for path in model_files.values():
        assert path.is_file()

    encoder, decoder, joiner = ncnn.Net(), ncnn.Net(), ncnn.Net()

    # Loading the encoder and decoder params is not working yet, so only
    # the joiner is loaded.
    joiner.load_param(str(model_files["joiner_jit_trace.ncnn.param"]))

    for net in (encoder, decoder, joiner):
        net.clear()

    logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Each log record shows time, level, source location and the message.
    logging.basicConfig(
        format=(
            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        ),
        level=logging.INFO,
    )
    main()
|
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
@ -0,0 +1,337 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX models and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    """Build the command-line parser of this script.

    Returns:
      Return an `argparse.ArgumentParser` that accepts the three required
      ONNX model filenames (encoder, decoder, joiner), `--bpe-model`,
      `--sample-rate`, `--context-size` and one or more positional
      `sound_files`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # This script loads the models with onnxruntime, so the help text
    # describes them as ONNX models.
    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder ONNX model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder ONNX model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner ONNX model. ",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load every given sound file and keep only its first channel.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The sample rate every file must have. An assertion fails for the
        first file whose sample rate differs.
    Returns:
      Return a list with one entry per file: the first channel of whatever
      `torchaudio.load()` returned for it.
    """

    def _first_channel(path: str) -> torch.Tensor:
        samples, rate = torchaudio.load(path)
        assert rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {rate}"
        )
        # We use only the first channel
        return samples[0]

    return [_first_channel(path) for path in filenames]
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
    decoder: ort.InferenceSession,
    joiner: ort.InferenceSession,
    encoder_out: np.ndarray,
    encoder_out_lens: np.ndarray,
    context_size: int,
) -> List[List[int]]:
    """Batched greedy search that emits at most one symbol per frame
    (i.e. it hardcodes --max-sym-per-frame=1).

    Args:
      decoder:
        The decoder model. Its first input receives the int64 token
        contexts and its first output is used as the decoder output.
      joiner:
        The joiner model. Its first input receives the encoder frames and
        its second input the decoder output.
      encoder_out:
        A 3-D array of shape (N, T, C).
      encoder_out_lens:
        A 1-D array of shape (N,). All entries must be positive.
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded token IDs of each utterance, in the same order as
      the rows of `encoder_out`.
    """
    encoder_out = torch.from_numpy(encoder_out)
    encoder_out_lens = torch.from_numpy(encoder_out_lens)
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # With enforce_sorted=False the utterances are sorted by decreasing
    # length internally, so at each time step the utterances that still
    # have frames are the first `batch_sizes[t]` of the sorted batch.
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = 0  # hard-code to 0

    frame_batch_sizes = packed.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == frame_batch_sizes[0], (N, frame_batch_sizes)

    # Every hypothesis starts with `context_size` blanks; they are stripped
    # again before returning.
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_inputs = decoder.get_inputs()
    decoder_outputs = decoder.get_outputs()

    joiner_inputs = joiner.get_inputs()
    joiner_outputs = joiner.get_outputs()

    def run_decoder(contexts):
        # contexts: one list of `context_size` token IDs per utterance.
        tokens = torch.tensor(
            contexts,
            dtype=torch.int64,
        )  # (len(contexts), context_size)
        return decoder.run(
            [decoder_outputs[0].name],
            {
                decoder_inputs[0].name: tokens.numpy(),
            },
        )[0].squeeze(1)

    decoder_out = run_decoder(hyps)

    # One chunk per time step; chunk t holds the frames of the utterances
    # that are still active at t, shape (batch_size, encoder_out_dim).
    for frame in packed.data.split(frame_batch_sizes):
        num_active = frame.size(0)

        # Utterances that have run out of frames drop off the end.
        decoder_out = decoder_out[:num_active]

        logits = torch.from_numpy(
            joiner.run(
                [joiner_outputs[0].name],
                {
                    joiner_inputs[0].name: frame.numpy(),
                    joiner_inputs[1].name: decoder_out,
                },
            )[0]
        )
        # logits'shape (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        emitted = False
        for hyp, token in zip(hyps, logits.argmax(dim=1).tolist()):
            if token != blank_id:
                hyp.append(token)
                emitted = True

        if emitted:
            # update decoder output
            decoder_out = run_decoder(
                [hyp[-context_size:] for hyp in hyps[:num_active]]
            )

    # Strip the leading blanks and restore the caller's utterance order.
    sorted_ans = [hyp[context_size:] for hyp in hyps]
    return [sorted_ans[i] for i in packed.unsorted_indices.tolist()]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """Decode the given sound files with the ONNX encoder/decoder/joiner.

    Fbank features are computed with kaldifeat, passed through the encoder,
    decoded with `greedy_search` and the resulting token IDs are converted
    to text with the BPE model. The transcripts are written to the log.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    session_opts = ort.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    # Sessions are created in the order encoder, decoder, joiner.
    encoder, decoder, joiner = (
        ort.InferenceSession(filename, sess_options=session_opts)
        for filename in (
            args.encoder_model_filename,
            args.decoder_model_filename,
            args.joiner_model_filename,
        )
    )

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    fbank_opts = kaldifeat.FbankOptions()
    fbank_opts.device = "cpu"
    fbank_opts.frame_opts.dither = 0
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.samp_freq = args.sample_rate
    fbank_opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(fbank_opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )

    logging.info("Decoding started")
    per_wave_features = fbank(waves)

    # Record the unpadded lengths before padding to a common length.
    feature_lengths = torch.tensor(
        [f.size(0) for f in per_wave_features], dtype=torch.int64
    )
    features = pad_sequence(
        per_wave_features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    encoder_inputs = encoder.get_inputs()
    encoder_outputs = encoder.get_outputs()
    encoder_out, encoder_out_lens = encoder.run(
        [encoder_outputs[0].name, encoder_outputs[1].name],
        {
            encoder_inputs[0].name: features.numpy(),
            encoder_inputs[1].name: feature_lengths.numpy(),
        },
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )

    # One "<filename>:\n<transcript>\n\n" section per input file.
    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Each log record shows time, level, source location and the message.
    logging.basicConfig(
        format=(
            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        ),
        level=logging.INFO,
    )
    main()
|
@ -15,7 +15,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Usage:
|
This script loads a checkpoint and uses it to decode waves.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
|
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
|
||||||
|
and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
|
||||||
|
and `nn.Conv2d`.
|
||||||
|
|
||||||
|
The scaled versions are required only at training time. It simplifies our
|
||||||
|
life to convert them to their non-scaled versions at inference time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
def _get_weight(self: torch.nn.Linear):
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bias(self: torch.nn.Linear):
|
||||||
|
return self.bias
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
    """Build an `nn.Linear` equivalent to the given `ScaledLinear`.

    Args:
      scaled_linear:
        The layer to be converted.
    Returns:
      Return a linear layer created on the device of the scaled layer's
      weight. It satisfies:

        scaled_linear(x) == linear(x)

      for any given input tensor `x`. The returned layer always has a bias;
      it is all zeros when `scaled_linear` has none.
    """
    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

    weight = scaled_linear.get_weight()
    bias = scaled_linear.get_bias()

    linear = torch.nn.Linear(
        in_features=scaled_linear.in_features,
        out_features=scaled_linear.out_features,
        bias=True,  # otherwise, it throws errors when converting to PNNX format.
        device=weight.device,
    )
    linear.weight.data.copy_(weight)

    if bias is None:
        # A zero bias leaves the output unchanged.
        linear.bias.data.zero_()
    else:
        linear.bias.data.copy_(bias)

    return linear
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
    """Convert an instance of ScaledConv1d to nn.Conv1d.

    Args:
      scaled_conv1d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv1d that has the same `forward()` behavior
      of the given `scaled_conv1d`. It is created on the device of the
      scaled layer's weight.
    """
    assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)

    weight = scaled_conv1d.get_weight()
    bias = scaled_conv1d.get_bias()
    has_bias = bias is not None

    conv1d = nn.Conv1d(
        in_channels=scaled_conv1d.in_channels,
        out_channels=scaled_conv1d.out_channels,
        kernel_size=scaled_conv1d.kernel_size,
        stride=scaled_conv1d.stride,
        padding=scaled_conv1d.padding,
        dilation=scaled_conv1d.dilation,
        groups=scaled_conv1d.groups,
        bias=scaled_conv1d.bias is not None,
        padding_mode=scaled_conv1d.padding_mode,
        # Keep the new layer on the same device as the layer it replaces,
        # as scaled_linear_to_linear() does; otherwise converting a model
        # that is not on the default device yields mixed devices.
        device=weight.device,
    )

    conv1d.weight.data.copy_(weight)
    if has_bias:
        conv1d.bias.data.copy_(bias)

    return conv1d
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
    """Convert an instance of ScaledConv2d to nn.Conv2d.

    Args:
      scaled_conv2d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv2d that has the same `forward()` behavior
      of the given `scaled_conv2d`. It is created on the device of the
      scaled layer's weight.
    """
    assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)

    weight = scaled_conv2d.get_weight()
    bias = scaled_conv2d.get_bias()
    has_bias = bias is not None

    conv2d = nn.Conv2d(
        in_channels=scaled_conv2d.in_channels,
        out_channels=scaled_conv2d.out_channels,
        kernel_size=scaled_conv2d.kernel_size,
        stride=scaled_conv2d.stride,
        padding=scaled_conv2d.padding,
        dilation=scaled_conv2d.dilation,
        groups=scaled_conv2d.groups,
        bias=scaled_conv2d.bias is not None,
        padding_mode=scaled_conv2d.padding_mode,
        # Keep the new layer on the same device as the layer it replaces,
        # as scaled_linear_to_linear() does; otherwise converting a model
        # that is not on the default device yields mixed devices.
        device=weight.device,
    )

    conv2d.weight.data.copy_(weight)
    if has_bias:
        conv2d.bias.data.copy_(bias)

    return conv2d
|
||||||
|
|
||||||
|
|
||||||
|
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
    """Replace the `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` layers
    of the given model with their unscaled versions `nn.Linear`,
    `nn.Conv1d`, and `nn.Conv2d`.

    `ScaledLinear` layers whose module name matches
    `self_attn\\.(in|out)_proj` are left as they are.

    Args:
      model:
        The model to be converted.
      inplace:
        If True, the input model is modified inplace.
        If False, the input model is copied and we modify the copied version.
    Return:
      Return the converted model.
    """
    if not inplace:
        model = copy.deepcopy(model)

    skip = re.compile(r"self_attn\.(in|out)_proj")

    # First collect all replacements, then install them, so that the module
    # tree is not modified while it is being traversed.
    replacements = {}
    for name, module in model.named_modules():
        if isinstance(module, ScaledLinear):
            if skip.search(name) is None:
                replacements[name] = scaled_linear_to_linear(module)
        elif isinstance(module, ScaledConv1d):
            replacements[name] = scaled_conv1d_to_conv1d(module)
        elif isinstance(module, ScaledConv2d):
            replacements[name] = scaled_conv2d_to_conv2d(module)

    for name, new_module in replacements.items():
        # "a.b.c" -> owner "a.b", attribute "c"; a name without a dot is an
        # attribute of the model itself.
        parent_name, dot, attr = name.rpartition(".")
        owner = model.get_submodule(parent_name) if dot else model
        setattr(owner, attr, new_module)

    return model
|
@ -0,0 +1,201 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless3/test_scaling_converter.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
|
||||||
|
from scaling_converter import (
|
||||||
|
convert_scaled_to_non_scaled,
|
||||||
|
scaled_conv1d_to_conv1d,
|
||||||
|
scaled_conv2d_to_conv2d,
|
||||||
|
scaled_linear_to_linear,
|
||||||
|
)
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model():
    """Create the transducer model used by the conversion test.

    The default parameters from `get_params()` are extended with the
    settings listed below, then passed to `get_transducer_model()` with
    `enable_giga=False`.
    """
    params = get_params()

    overrides = (
        ("vocab_size", 500),
        ("blank_id", 0),
        ("context_size", 2),
        ("unk_id", 2),
        ("dynamic_chunk_training", False),
        ("short_chunk_size", 25),
        ("num_left_chunks", 4),
        ("causal_convolution", False),
    )
    for key, value in overrides:
        setattr(params, key, value)

    return get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_linear_to_linear():
    """Check that a converted linear layer matches the scaled one, with and
    without bias, both in eager mode and after `torch.jit.script`."""
    batch = 5
    num_in = 10
    num_out = 20
    for with_bias in (True, False):
        scaled = ScaledLinear(
            in_features=num_in,
            out_features=num_out,
            bias=with_bias,
        )
        plain = scaled_linear_to_linear(scaled)
        x = torch.rand(batch, num_in)

        # Eager outputs.
        eager_scaled = scaled(x)
        eager_plain = plain(x)
        assert torch.allclose(eager_scaled, eager_plain)

        # Scripted outputs.
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        jit_scaled_out = scripted_scaled(x)
        jit_plain_out = scripted_plain(x)

        assert torch.allclose(jit_scaled_out, jit_plain_out)
        assert torch.allclose(eager_scaled, jit_plain_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv1d_to_conv1d():
    """Check that a converted 1-D convolution matches the scaled one, with
    and without bias, both in eager mode and after `torch.jit.script`."""
    num_in = 3
    for with_bias in (True, False):
        scaled = ScaledConv1d(
            num_in,
            6,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=with_bias,
        )
        plain = scaled_conv1d_to_conv1d(scaled)

        x = torch.rand(20, num_in, 10)

        # Eager outputs.
        eager_scaled = scaled(x)
        eager_plain = plain(x)
        assert torch.allclose(eager_scaled, eager_plain)

        # Scripted outputs.
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        jit_scaled_out = scripted_scaled(x)
        jit_plain_out = scripted_plain(x)

        assert torch.allclose(jit_scaled_out, jit_plain_out)
        assert torch.allclose(eager_scaled, jit_plain_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv2d_to_conv2d():
    """Check that a converted 2-D convolution matches the scaled one, with
    and without bias, both in eager mode and after `torch.jit.script`."""
    num_in = 1
    for with_bias in (True, False):
        scaled = ScaledConv2d(
            in_channels=num_in,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=with_bias,
        )
        plain = scaled_conv2d_to_conv2d(scaled)

        x = torch.rand(20, num_in, 10, 20)

        # Eager outputs.
        eager_scaled = scaled(x)
        eager_plain = plain(x)
        assert torch.allclose(eager_scaled, eager_plain)

        # Scripted outputs.
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        jit_scaled_out = scripted_scaled(x)
        jit_plain_out = scripted_plain(x)

        assert torch.allclose(jit_scaled_out, jit_plain_out)
        assert torch.allclose(eager_scaled, jit_plain_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_scaled_to_non_scaled():
    """Check that a converted model produces the same outputs as the
    original for the encoder, the decoder, the simple projections and the
    joiner, for both `inplace=False` and `inplace=True`."""
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        # With inplace=True the conversion modifies `model` itself, so keep
        # an untouched copy to compare against.
        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens = model.encoder(x, x_lens)
        e2, e2_lens = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        # Use the converted decoder here; calling the original decoder
        # twice would make the comparison below trivially true.
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """Run all tests of this file in order, with gradients disabled."""
    all_tests = (
        test_scaled_linear_to_linear,
        test_scaled_conv1d_to_conv1d,
        test_scaled_conv2d_to_conv2d,
        test_convert_scaled_to_non_scaled,
    )
    for run_test in all_tests:
        run_test()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Fix the seed so the random weights and inputs are the same on every run.
    seed = 20220730
    torch.manual_seed(seed)
    main()
|
Loading…
x
Reference in New Issue
Block a user