From c08c6ae0ecbfea8fba2685878e58af6287c303e9 Mon Sep 17 00:00:00 2001 From: Guanbo Wang Date: Tue, 3 May 2022 22:15:55 +0000 Subject: [PATCH] gigaspeech decode --- .../ASR/pruned_transducer_stateless2/decode.py | 18 ++++++++++++++---- .../gigaspeech_scoring.py | 1 + 2 files changed, 15 insertions(+), 4 deletions(-) create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 3ae75cef3..e5810749b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -74,6 +74,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) +from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model from icefall.checkpoint import ( @@ -200,6 +201,17 @@ def get_parser(): return parser +def post_processing( + results: List[Tuple[List[str], List[str]]], +) -> List[Tuple[List[str], List[str]]]: + new_results = [] + for ref, hyp in results: + new_ref = asr_text_post_processing(' '.join(ref)).split() + new_hyp = asr_text_post_processing(' '.join(hyp)).split() + new_results.append((new_ref, new_hyp)) + return new_results + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -354,10 +366,7 @@ def decode_dataset( except TypeError: num_batches = "?" - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 + log_interval = 100 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -401,6 +410,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = post_processing(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py new file mode 120000 index 000000000..a6a4d12b1 --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py @@ -0,0 +1 @@ +../conformer_ctc/gigaspeech_scoring.py \ No newline at end of file