This commit is contained in:
AmirHussein96 2025-09-13 10:40:45 -04:00
parent 947ae0a73c
commit 229883e828

View File

@ -72,7 +72,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from train_st import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -529,7 +529,7 @@ def decode_dataset(
word_table=word_table,
batch=batch,
)
#breakpoint()
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
@ -540,7 +540,7 @@ def decode_dataset(
this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words))
results[name].extend(this_batch)
#breakpoint()
num_cuts += len(texts)
if batch_idx % log_interval == 0: