mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Decode with post-processing
This commit is contained in:
parent
6d07cf9245
commit
f485b66d54
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -28,6 +29,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -170,6 +172,17 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
    """Apply GigaSpeech text normalization to every (ref, hyp) pair.

    Args:
      results:
        A list of (reference words, hypothesis words) pairs, one per
        utterance.
    Returns:
      A new list with the same structure where each transcript has been
      normalized by ``asr_text_post_processing`` (uppercased, hyphens
      replaced by spaces, non-scoring words removed).
    """
    new_results = []
    for ref, hyp in results:
        # asr_text_post_processing returns a single string; split it back
        # into a word list so the declared return type holds and downstream
        # word-level scoring (store_transcripts / WER computation) receives
        # List[str], not a raw str.
        new_ref = asr_text_post_processing(" ".join(ref)).split()
        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((new_ref, new_hyp))
    return new_results
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -488,7 +501,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
if params.method == "attention-decoder":
|
||||
# Set it to False since there are too many logs.
|
||||
@ -498,6 +511,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = post_processing(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
115
egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
Executable file
115
egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
Executable file
@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Jiayu Du
|
||||
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# Filler words that carry no lexical content and are excluded from scoring.
conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
# Unknown-word tags (both casings appear in transcripts).
unk_tags = ["<UNK>", "<unk>"]
# GigaSpeech spoken-punctuation annotation tags.
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
# Tags marking utterances that contain no scorable speech.
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
# Public list kept as a list (same order/contents) for backward
# compatibility with any external importers.
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)
# Frozen set for O(1) membership tests in the per-word filter loop below.
_NON_SCORING_SET = frozenset(non_scoring_words)


def asr_text_post_processing(text: str) -> str:
    """Normalize a transcript for GigaSpeech scoring.

    Steps:
      1. Uppercase the text.
      2. Replace hyphens with spaces, e.g.
         "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART".
      3. Remove non-scoring words (fillers, unk tags, punctuation tags,
         garbage-utterance tags) from evaluation.

    Args:
      text:
        A raw transcript (reference or hypothesis).
    Returns:
      The normalized transcript as a single space-joined string.
    """
    words = text.upper().replace("-", " ").split()
    return " ".join(w for w in words if w not in _NON_SCORING_SET)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI wrapper: post-process reference/hypothesis trn files, then score
    # them with SCTK's sclite tool.
    parser = argparse.ArgumentParser(
        description=(
            "This script evaluates GigaSpeech ASR result via"
            "SCTK's tool sclite"
        )
    )
    parser.add_argument(
        "ref",
        type=str,
        help="sclite's standard transcription(trn) reference file",
    )
    parser.add_argument(
        "hyp",
        type=str,
        help="sclite's standard transcription(trn) hypothesis file",
    )
    parser.add_argument(
        "work_dir",
        type=str,
        help="working dir",
    )
    args = parser.parse_args()

    if not os.path.isdir(args.work_dir):
        os.mkdir(args.work_dir)

    # Normalized copies and the sclite report all live in work_dir.
    REF = os.path.join(args.work_dir, "REF")
    HYP = os.path.join(args.work_dir, "HYP")
    RESULT = os.path.join(args.work_dir, "RESULT")

    # Rewrite each trn file with post-processed text; the last column is
    # the utterance-id field and is passed through untouched.
    for src_path, dst_path in ((args.ref, REF), (args.hyp, HYP)):
        with open(src_path, "r", encoding="utf8") as fi, open(
            dst_path, "w+", encoding="utf8"
        ) as fo:
            for raw_line in fi:
                stripped = raw_line.strip()
                if not stripped:
                    continue
                *text_cols, uttid_field = stripped.split()
                text = asr_text_post_processing(" ".join(text_cols))
                print(f"{text} {uttid_field}", file=fo)

    # GigaSpeech's uttid conforms to the swb (Switchboard) id convention.
    os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}")
|
Loading…
x
Reference in New Issue
Block a user