Do some changes for aishell/ASR/transducer stateless/export.py (#347)

* do some changes for aishell/ASR/transducer_stateless/export.py
This commit is contained in:
Mingshuang Luo 2022-05-07 11:09:31 +08:00 committed by GitHub
parent c059ef3169
commit f783e10dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -22,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
@ -33,20 +34,19 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -56,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
@ -91,10 +92,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
@ -194,12 +195,10 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)