add d file

This commit is contained in:
marcoyang 2023-01-04 10:50:59 +08:00
parent d0eb9b1912
commit eb25b173dc

View File

@ -96,6 +96,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
DecodingResults,
@ -167,6 +168,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -286,6 +294,8 @@ def decode_one_batch(
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if isinstance(encoder_out, list):
encoder_out = encoder_out[-1] # the last item is final output
hyps = []
if params.decoding_method == "fast_beam_search":
@ -345,12 +355,10 @@ def decode_one_batch(
res = DecodingResults(hyps=tokens, timestamps=timestamps)
hyps, timestamps = parse_hyp_and_timestamp(
decoding_method=params.decoding_method,
res=res,
sp=sp,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
word_table=word_table,
)
if params.decoding_method == "greedy_search":
@ -533,6 +541,7 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
import pdb; pdb.set_trace()
params = get_params()
params.update(vars(args))
@ -669,6 +678,9 @@ def main():
else:
decoding_graph = None
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")