Add more doc.

This commit is contained in:
Fangjun Kuang 2022-08-01 19:51:23 +08:00
parent eb549bf5d0
commit 0d5edbc3f2
2 changed files with 30 additions and 2 deletions

View File

@ -93,7 +93,7 @@ class Decoder(nn.Module):
Return a tensor of shape (N, U, decoder_dim).
"""
if isinstance(need_pad, torch.Tensor):
# This if for torch.jit.trace(), which cannot handle the case
# This is for torch.jit.trace(), which cannot handle the case
# when the input argument is not a tensor.
need_pad = bool(need_pad)

View File

@ -19,6 +19,7 @@
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
@ -36,7 +37,7 @@ load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
It will also generates 3 other files: `encoder_jit_script.pt`,
It will also generate 3 other files: `encoder_jit_script.pt`,
`decoder_jit_script.pt`, and `joiner_jit_script.pt`.
(2) Export to torchscript model using torch.jit.trace()
@ -96,6 +97,18 @@ you can do:
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""
import argparse
@ -170,6 +183,13 @@ def get_parser():
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate 4 files:
- encoder_jit_script.pt
- decoder_jit_script.pt
- joiner_jit_script.pt
- cpu_jit.pt (which combines the above 3 files)
Check ./jit_pretrained.py for how to use them.
""",
)
@ -178,6 +198,12 @@ def get_parser():
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
@ -191,6 +217,8 @@ def get_parser():
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)