mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix torchscript export to use tokens.txt instead of lang_dir (#1475)
This commit is contained in:
parent
c401a2646b
commit
8d39f9508b
25
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
25
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
@ -20,7 +21,7 @@
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 29 \
|
||||
--avg 19
|
||||
|
||||
@ -45,12 +46,13 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -85,10 +87,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -122,10 +124,14 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
# Load id of the <blk> token and the vocab size
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -152,6 +158,7 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
|
1
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
1
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
@ -22,7 +22,7 @@
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
@ -47,12 +47,13 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -98,10 +99,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
default="data/lang_bpe_500/tokens.txt",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -135,12 +136,14 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
# Load id of the <blk> token and the vocab size
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -183,6 +186,7 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
|
1
egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
1
egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
# TODO(fangjun): Change the function name since we are actually using
|
||||
# torch.jit.script instead of torch.jit.trace
|
||||
traced_model = torch.jit.script(decoder_model)
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
@ -159,6 +159,7 @@ def main():
|
||||
|
||||
# Load id of the <blk> token and the vocab size
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||
|
||||
logging.info(params)
|
||||
|
@ -91,7 +91,7 @@ class Decoder(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
embedding_out = self.embedding(y)
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
|
@ -26,7 +26,7 @@ Usage:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
@ -87,7 +87,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
|
||||
ln -s pretrained.pt epoch-999.pt
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model False \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
@ -113,7 +113,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
|
||||
ln -s pretrained.pt epoch-999.pt
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model False \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
|
@ -23,7 +23,7 @@
|
||||
Usage:
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir ./data/lang_char \
|
||||
--tokens ./data/lang_char/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 24 \
|
||||
--use-averaged-model True
|
||||
@ -50,8 +50,9 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -60,8 +61,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -118,13 +118,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -160,13 +157,14 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
bpe_model = params.lang_dir + "/bpe.model"
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpe_model)
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
# Load id of the <blk> token and the vocab size
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -256,6 +254,7 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
|
1
egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
1
egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
@ -24,7 +24,7 @@ Usage:
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit 1
|
||||
@ -47,7 +47,7 @@ for how to use them.
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit-trace 1
|
||||
@ -63,7 +63,7 @@ Check ./jit_pretrained.py for usage.
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 10 \
|
||||
--avg 2
|
||||
|
||||
@ -91,14 +91,14 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -133,10 +133,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -313,10 +313,9 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -20,7 +20,7 @@
|
||||
Usage for offline:
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 4 \
|
||||
--avg 1
|
||||
|
||||
@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR.
|
||||
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 4 \
|
||||
--avg 1 \
|
||||
--jit True
|
||||
@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR.
|
||||
Usage for streaming:
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 7 \
|
||||
--avg 1
|
||||
|
||||
@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR.
|
||||
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 7 \
|
||||
--avg 1 \
|
||||
--jit True
|
||||
@ -73,13 +73,13 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -114,10 +114,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -152,10 +152,9 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user