from local

This commit is contained in:
dohe0342 2023-02-02 19:07:21 +09:00
parent d936ec48a9
commit 8f19e1301e
2 changed files with 13 additions and 1 deletions

View File

@ -423,7 +423,7 @@ def decode_dataset(
logging.info(f"decoding {batch_idx} th batch")
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
'''
hyps_dict = decode_one_batch(
params=params,
model=model,
@ -435,6 +435,18 @@ def decode_dataset(
eos_id=eos_id,
token_dict=token_dict,
)
'''
decode_one_batch_greedy(
params=params,
model=model,
HLG=HLG,
H=H,
batch=batch,
lexicon=lexicon,
sos_id=sos_id,
eos_id=eos_id,
token_dict=token_dict,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []