mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
gigaspeech decode
This commit is contained in:
parent
d6390fd107
commit
c08c6ae0ec
@ -74,6 +74,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -200,6 +201,17 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def post_processing(
|
||||
results: List[Tuple[List[str], List[str]]],
|
||||
) -> List[Tuple[List[str], List[str]]]:
|
||||
new_results = []
|
||||
for ref, hyp in results:
|
||||
new_ref = asr_text_post_processing(' '.join(ref)).split()
|
||||
new_hyp = asr_text_post_processing(' '.join(hyp)).split()
|
||||
new_results.append((new_ref, new_hyp))
|
||||
return new_results
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -354,10 +366,7 @@ def decode_dataset(
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 2
|
||||
log_interval = 100
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -401,6 +410,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = post_processing(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
|
@ -0,0 +1 @@
|
||||
../conformer_ctc/gigaspeech_scoring.py
|
Loading…
x
Reference in New Issue
Block a user