Make some changes to aishell/ASR/transducer_stateless/export.py

This commit is contained in:
luomingshuang 2022-05-06 11:27:41 +08:00
parent 00c48ec1f3
commit 7c6a9bd817

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -22,7 +23,7 @@
Usage: Usage:
./transducer_stateless/export.py \ ./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -33,20 +34,19 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 1 \ --max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model --lang-dir data/lang_char
""" """
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from conformer import Conformer
@ -56,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool from icefall.utils import AttributeDict, str2bool
@ -91,10 +92,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_char",
help="Path to the BPE model", help="The lang dir",
) )
parser.add_argument( parser.add_argument(
@ -194,12 +195,11 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model)
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
# <blk> is defined in local/train_bpe_model.py params.blank_id = 0
params.blank_id = sp.piece_to_id("<blk>") params.vocab_size = max(lexicon.tokens) + 1
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)