mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
gigaspeech decode
This commit is contained in:
parent
d6390fd107
commit
c08c6ae0ec
@ -74,6 +74,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -200,6 +201,17 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def post_processing(
|
||||||
|
results: List[Tuple[List[str], List[str]]],
|
||||||
|
) -> List[Tuple[List[str], List[str]]]:
|
||||||
|
new_results = []
|
||||||
|
for ref, hyp in results:
|
||||||
|
new_ref = asr_text_post_processing(' '.join(ref)).split()
|
||||||
|
new_hyp = asr_text_post_processing(' '.join(hyp)).split()
|
||||||
|
new_results.append((new_ref, new_hyp))
|
||||||
|
return new_results
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -354,10 +366,7 @@ def decode_dataset(
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
log_interval = 100
|
||||||
log_interval = 100
|
|
||||||
else:
|
|
||||||
log_interval = 2
|
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -401,6 +410,7 @@ def save_results(
|
|||||||
recog_path = (
|
recog_path = (
|
||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
|
results = post_processing(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
../conformer_ctc/gigaspeech_scoring.py
|
Loading…
x
Reference in New Issue
Block a user