diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index e391860b9..10e38c501 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -11,10 +11,8 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo ## TODO List - [x] Incorporate Lhotse for data processing - [x] Further text normalization -- [ ] Refer to Global Mapping Rules when computing Word Error Rate - [x] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset -- [ ] Switchboard transcript train/dev split for LM training -- [ ] Fisher corpus LDC2004T19 LDC2005T19 LDC2004S13 LDC2005S13 for LM training +- [x] Switchboard transcript train/dev split for LM training ## Performance Record | | eval2000 | rt03 | @@ -30,3 +28,5 @@ The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ct A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). + +The `sclite_scoring.py` is from the GigaSpeech recipe for post processing and glm-like scoring, which is definitely not an elegant stuff to do. diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 68b84a5f8..396a264a2 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -30,6 +30,8 @@ import torch.nn as nn from asr_datamodule import SwitchBoardAsrDataModule from conformer import Conformer +from sclite_scoring import asr_text_post_processing + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.decode import ( @@ -233,6 +235,17 @@ def get_params() -> AttributeDict: return params +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -591,6 +604,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) results = ( sorted(list(filter(lambda x: x[0].startswith(prefix), results))) if subset != "avg" @@ -605,7 +619,11 @@ def save_results( errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{subset}-{key}", results, enable_log=enable_log + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=enable_log, + sclite_mode=True, ) test_set_wers[key] = wer diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 000000000..0ad1fc2e9 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]" +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index b07260f5b..f3299353b 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -108,6 +108,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + # ./local/rt03_data_prep.sh $rt03_dir # normalize eval2000 and rt03 texts by