mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-22 08:16:14 +00:00
add d file
This commit is contained in:
parent
d0eb9b1912
commit
eb25b173dc
@ -96,6 +96,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
DecodingResults,
|
DecodingResults,
|
||||||
@ -167,6 +168,13 @@ def get_parser():
|
|||||||
help="Path to the BPE model",
|
help="Path to the BPE model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@ -286,6 +294,8 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
if isinstance(encoder_out, list):
|
||||||
|
encoder_out = encoder_out[-1] # the last item is final output
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
@ -345,12 +355,10 @@ def decode_one_batch(
|
|||||||
res = DecodingResults(hyps=tokens, timestamps=timestamps)
|
res = DecodingResults(hyps=tokens, timestamps=timestamps)
|
||||||
|
|
||||||
hyps, timestamps = parse_hyp_and_timestamp(
|
hyps, timestamps = parse_hyp_and_timestamp(
|
||||||
decoding_method=params.decoding_method,
|
|
||||||
res=res,
|
res=res,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
frame_shift_ms=params.frame_shift_ms,
|
frame_shift_ms=params.frame_shift_ms,
|
||||||
word_table=word_table,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
@ -533,6 +541,7 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -669,6 +678,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
word_table = lexicon.word_table
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user