def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
    """Apply GigaSpeech text normalisation to every (ref, hyp) pair.

    Each side of a pair is a list of words.  The words are joined with
    single spaces, passed through ``asr_text_post_processing`` (imported
    from ``gigaspeech_scoring``) and split back into a list of words so
    that the returned structure has the same shape as the input.

    Args:
      results:
        A list of ``(ref_words, hyp_words)`` tuples.

    Returns:
      A new list of ``(ref_words, hyp_words)`` tuples holding the
      normalised words.  The input list is not modified.
    """
    new_results = []
    for ref, hyp in results:
        # asr_text_post_processing returns one space-joined string.  It has
        # to be split again: handing a plain string to code that expects a
        # list of words would make it iterate over characters instead.
        new_ref = asr_text_post_processing(" ".join(ref)).split()
        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((new_ref, new_hyp))
    return new_results
#!/usr/bin/env python3
# Copyright    2021  Jiayu Du
# Copyright    2022  Johns Hopkins University (Author: Guanbo Wang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text normalisation used when scoring GigaSpeech ASR output.

Importing this module gives access to ``asr_text_post_processing``.
Running it as a script normalises a reference and a hypothesis ``trn``
file and scores them with SCTK's ``sclite``.
"""

import argparse
import os
import shlex

conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
# NOTE(review): the tag literals below were empty strings in the reviewed
# text, which can never match a token produced by ``str.split()``.  They are
# restored to the angle-bracket tags that the variable names describe;
# confirm against the GigaSpeech toolkit's scoring script.
#
# Text is upper-cased before filtering, so the lowercase "<unk>" entry is
# only reachable through its upper-case twin; it is kept for completeness.
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)


def asr_text_post_processing(text: str) -> str:
    """Normalise a transcript before it is scored.

    Three steps are applied in order:

      1. convert the text to uppercase;
      2. replace every hyphen with a space, e.g.
         "E-COMMERCE" -> "E COMMERCE",
         "STATE-OF-THE-ART" -> "STATE OF THE ART";
      3. drop every whitespace-separated word listed in
         ``non_scoring_words``.

    Args:
      text:
        The transcript as a single string.

    Returns:
      The remaining words joined by single spaces.  The result is an empty
      string when no word survives.
    """
    # Hyphens are replaced before filtering, so the parts of a hyphenated
    # word are filtered individually (e.g. "UH-HUH" disappears entirely).
    words = text.upper().replace("-", " ").split()
    return " ".join(word for word in words if word not in non_scoring_words)


def _normalize_trn_line(line: str) -> str:
    """Normalise one non-empty trn line, keeping its last field untouched."""
    cols = line.split()
    text = asr_text_post_processing(" ".join(cols[:-1]))
    # The last whitespace-separated field is treated as the utterance id.
    uttid_field = cols[-1]
    return f"{text} {uttid_field}"


def _normalize_trn_file(src: str, dst: str) -> None:
    """Copy trn file ``src`` to ``dst``, normalising each non-blank line."""
    with open(src, "r", encoding="utf8") as fi:
        with open(dst, "w+", encoding="utf8") as fo:
            for line in fi:
                line = line.strip()
                if line:
                    print(_normalize_trn_line(line), file=fo)


def main() -> None:
    """Normalise the ref/hyp trn files and score them with sclite."""
    parser = argparse.ArgumentParser(
        description="This script evaluates GigaSpeech ASR result via "
        "SCTK's tool sclite"
    )
    parser.add_argument(
        "ref",
        type=str,
        help="sclite's standard transcription(trn) reference file",
    )
    parser.add_argument(
        "hyp",
        type=str,
        help="sclite's standard transcription(trn) hypothesis file",
    )
    parser.add_argument(
        "work_dir",
        type=str,
        help="working dir",
    )
    args = parser.parse_args()

    # Also creates missing parent directories and tolerates an existing one.
    os.makedirs(args.work_dir, exist_ok=True)

    ref_path = os.path.join(args.work_dir, "REF")
    hyp_path = os.path.join(args.work_dir, "HYP")
    result_path = os.path.join(args.work_dir, "RESULT")

    for src, dst in [(args.ref, ref_path), (args.hyp, hyp_path)]:
        _normalize_trn_file(src, dst)

    # GigaSpeech's uttid conforms to swb.
    # The pipeline needs a shell for "| tee"; quote the paths so that a
    # work_dir containing spaces or shell metacharacters stays one argument.
    os.system(
        f"sclite -r {shlex.quote(ref_path)} trn "
        f"-h {shlex.quote(hyp_path)} trn -i swb "
        f"| tee {shlex.quote(result_path)}"
    )


if __name__ == "__main__":
    main()