mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Support torch script. (#65)
* WIP: Support torchscript. * Minor fixes. * Fix style issues. * Add documentation about how to deploy a trained model.
This commit is contained in:
parent
d54828e73a
commit
beb54ddb61
15
README.md
15
README.md
@ -55,7 +55,22 @@ The WER for this model is:
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
|
||||
|
||||
|
||||
## Deployment with C++
|
||||
|
||||
Once you have trained a model in icefall, you may want to deploy it with C++,
|
||||
without Python dependencies.
|
||||
|
||||
Please refer to the documentation
|
||||
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html#deployment-with-c>
|
||||
for how to do this.
|
||||
|
||||
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
|
||||
Please see: [](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
|
||||
|
||||
|
||||
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
|
||||
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
|
||||
[yesno]: egs/yesno/ASR
|
||||
[librispeech]: egs/librispeech/ASR
|
||||
[k2]: https://github.com/k2-fsa/k2
|
||||
|
@ -20,6 +20,7 @@ In this tutorial, you will learn:
|
||||
- (2) How to start the training, either with a single GPU or multiple GPUs
|
||||
- (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
|
||||
- (4) How to use a pre-trained model, provided by us
|
||||
- (5) How to deploy your trained model in C++, without Python dependencies
|
||||
|
||||
Data preparation
|
||||
----------------
|
||||
@ -292,12 +293,12 @@ The commonly used options are:
|
||||
|
||||
- ``--method``
|
||||
|
||||
This specifies the decoding method. This script supports 7 decoding methods.
|
||||
As for ctc decoding, it uses a sentence piece model to convert word pieces to words.
|
||||
This specifies the decoding method. This script supports 7 decoding methods.
|
||||
As for ctc decoding, it uses a sentence piece model to convert word pieces to words.
|
||||
And it needs neither a lexicon nor an n-gram LM.
|
||||
|
||||
|
||||
For example, the following command uses CTC topology for decoding:
|
||||
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ cd egs/librispeech/ASR
|
||||
@ -334,20 +335,20 @@ Usage:
|
||||
--exp-dir conformer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method ctc-decoding
|
||||
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started
|
||||
2021-09-26 12:44:31,033 INFO [decode.py:538]
|
||||
{'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True,
|
||||
2021-09-26 12:44:31,033 INFO [decode.py:538]
|
||||
{'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True,
|
||||
'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8,
|
||||
'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True,
|
||||
'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5,
|
||||
'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False,
|
||||
'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30,
|
||||
'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False,
|
||||
'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True,
|
||||
'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5,
|
||||
'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False,
|
||||
'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30,
|
||||
'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False,
|
||||
'shuffle': True, 'return_cuts': True, 'num_workers': 2}
|
||||
2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt
|
||||
2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0
|
||||
@ -373,7 +374,7 @@ The output is given below:
|
||||
For test-other, WER of different settings are:
|
||||
ctc-decoding 8.21 best for test-other
|
||||
|
||||
2021-09-26 12:47:16,433 INFO [decode.py:680] Done!
|
||||
2021-09-26 12:47:16,433 INFO [decode.py:680] Done!
|
||||
|
||||
Pre-trained Model
|
||||
-----------------
|
||||
@ -693,3 +694,119 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
|
||||
|
||||
**Congratulations!** You have finished the librispeech ASR recipe with
|
||||
conformer CTC models in ``icefall``.
|
||||
|
||||
If you want to deploy your trained model in C++, please read the following section.
|
||||
|
||||
Deployment with C++
|
||||
-------------------
|
||||
|
||||
This section describes how to deploy your trained model in C++, without
|
||||
Python dependencies.
|
||||
|
||||
We assume you have run ``./prepare.sh`` and have the following directories available:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
data
|
||||
|-- lang_bpe
|
||||
|
||||
Also, we assume your checkpoints are saved in ``conformer_ctc/exp``.
|
||||
|
||||
If you know that averaging 20 checkpoints starting from ``epoch-30.pt`` yields the
|
||||
lowest WER, you can run the following commands
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ cd egs/librispeech/ASR
|
||||
$ ./conformer_ctc/export.py \
|
||||
--epoch 30 \
|
||||
--avg 20 \
|
||||
--jit 1 \
|
||||
--lang-dir data/lang_bpe \
|
||||
--exp-dir conformer_ctc/exp
|
||||
|
||||
to get a torch scripted model saved in ``conformer_ctc/exp/cpu_jit.pt``.
|
||||
|
||||
Now you have all needed files ready. Let us compile k2 from source:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ cd $HOME
|
||||
$ git clone https://github.com/k2-fsa/k2
|
||||
$ cd k2
|
||||
$ git checkout v2.0-pre
|
||||
|
||||
.. CAUTION::
|
||||
|
||||
You have to switch to the branch ``v2.0-pre``!
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ mkdir build-release
|
||||
$ cd build-release
|
||||
$ cmake -DCMAKE_BUILD_TYPE=Release ..
|
||||
$ make -j decode
|
||||
# You will find an executable: `./bin/decode`
|
||||
|
||||
Now you are ready to go!
|
||||
|
||||
To view the usage of ``./bin/decode``, run:
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ ./bin/decode
|
||||
|
||||
It will show you the following message:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Please provide --jit_pt
|
||||
|
||||
(1) CTC decoding
|
||||
./bin/decode \
|
||||
--use_ctc_decoding true \
|
||||
--jit_pt <path to exported torch script pt file> \
|
||||
--bpe_model <path to pretrained BPE model> \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
<more wave files if any>
|
||||
(2) HLG decoding
|
||||
./bin/decode \
|
||||
--use_ctc_decoding false \
|
||||
--jit_pt <path to exported torch script pt file> \
|
||||
--hlg <path to HLG.pt> \
|
||||
--word-table <path to words.txt> \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
<more wave files if any>
|
||||
|
||||
--use_gpu false to use CPU
|
||||
--use_gpu true to use GPU
|
||||
|
||||
``./bin/decode`` supports two types of decoding at present: CTC decoding and HLG decoding.
|
||||
|
||||
CTC decoding
|
||||
^^^^^^^^^^^^
|
||||
|
||||
You need to provide:
|
||||
|
||||
- ``--jit_pt``, this is the file generated by ``conformer_ctc/export.py``. You can find it
|
||||
in ``conformer_ctc/exp/cpu_jit.pt``.
|
||||
- ``--bpe_model``, this is a sentence piece model generated by ``prepare.sh``. You can find
|
||||
it in ``data/lang_bpe/bpe.model``.
|
||||
|
||||
|
||||
HLG decoding
|
||||
^^^^^^^^^^^^
|
||||
|
||||
You need to provide:
|
||||
|
||||
- ``--jit_pt``, this is the same file as in CTC decoding.
|
||||
- ``--hlg``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/HLG.pt``.
|
||||
- ``--word-table``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/words.txt``.
|
||||
|
||||
We do provide a Colab notebook, showing you how to run a torch scripted model in C++.
|
||||
Please see |librispeech asr conformer ctc torch script colab notebook|
|
||||
|
||||
.. |librispeech asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
|
||||
:target: https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing
|
||||
|
165
egs/librispeech/ASR/conformer_ctc/export.py
Executable file
165
egs/librispeech/ASR/conformer_ctc/export.py
Executable file
@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=34,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe",
|
||||
help="""It contains language related input files such as "lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"use_feat_batchnorm": True,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"num_decoder_layers": 6,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=False,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames))
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -236,6 +236,7 @@ class Transformer(nn.Module):
|
||||
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
|
||||
return x
|
||||
|
||||
@torch.jit.export
|
||||
def decoder_forward(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
@ -264,11 +265,15 @@ class Transformer(nn.Module):
|
||||
"""
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device)
|
||||
@ -301,6 +306,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return decoder_loss
|
||||
|
||||
@torch.jit.export
|
||||
def decoder_nll(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
@ -331,11 +337,15 @@ class Transformer(nn.Module):
|
||||
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
||||
ys_in_pad = pad_sequence(
|
||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||
)
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
||||
ys_out_pad = pad_sequence(
|
||||
ys_out, batch_first=True, padding_value=float(-1)
|
||||
)
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||
@ -649,7 +659,8 @@ class PositionalEncoding(nn.Module):
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.pe = None
|
||||
# not doing: self.pe = None because of errors thrown by torchscript
|
||||
self.pe = torch.zeros(0, self.d_model, dtype=torch.float32)
|
||||
|
||||
def extend_pe(self, x: torch.Tensor) -> None:
|
||||
"""Extend the time t in the positional encoding if required.
|
||||
@ -666,8 +677,7 @@ class PositionalEncoding(nn.Module):
|
||||
"""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
@ -972,10 +982,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
|
||||
Return a new list-of-list, where each sublist starts
|
||||
with SOS ID.
|
||||
"""
|
||||
ans = []
|
||||
for utt in token_ids:
|
||||
ans.append([sos_id] + utt)
|
||||
return ans
|
||||
return [[sos_id] + utt for utt in token_ids]
|
||||
|
||||
|
||||
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
||||
@ -992,7 +999,4 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
||||
Return a new list-of-list, where each sublist ends
|
||||
with EOS ID.
|
||||
"""
|
||||
ans = []
|
||||
for utt in token_ids:
|
||||
ans.append(utt + [eos_id])
|
||||
return ans
|
||||
return [utt + [eos_id] for utt in token_ids]
|
||||
|
@ -41,6 +41,7 @@ dl_dir=$PWD/download
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
5000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
|
Loading…
x
Reference in New Issue
Block a user