Do some changes for aishell/ASR/transducer_stateless/export.py (#347)

* do some changes for aishell/ASR/transducer_stateless/export.py
This commit is contained in:
Mingshuang Luo 2022-05-07 11:09:31 +08:00 committed by GitHub
parent c059ef3169
commit f783e10dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@@ -22,7 +23,7 @@
Usage: Usage:
./transducer_stateless/export.py \ ./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@@ -33,20 +34,19 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 1 \ --max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model --lang-dir data/lang_char
""" """
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from conformer import Conformer
@@ -56,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool from icefall.utils import AttributeDict, str2bool
@@ -91,10 +92,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_char",
help="Path to the BPE model", help="The lang dir",
) )
parser.add_argument( parser.add_argument(
@@ -194,12 +195,10 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py params.blank_id = 0
params.blank_id = sp.piece_to_id("<blk>") params.vocab_size = max(lexicon.tokens) + 1
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)