gigaspeech decode

This commit is contained in:
Guanbo Wang 2022-05-03 22:15:55 +00:00
parent d6390fd107
commit c08c6ae0ec
2 changed files with 15 additions and 4 deletions

View File

@ -74,6 +74,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
from icefall.checkpoint import (
@ -200,6 +201,17 @@ def get_parser():
return parser
def post_processing(
results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
new_results = []
for ref, hyp in results:
new_ref = asr_text_post_processing(' '.join(ref)).split()
new_hyp = asr_text_post_processing(' '.join(hyp)).split()
new_results.append((new_ref, new_hyp))
return new_results
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -354,10 +366,7 @@ def decode_dataset(
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
log_interval = 100
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -401,6 +410,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = post_processing(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")

View File

@ -0,0 +1 @@
../conformer_ctc/gigaspeech_scoring.py