mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Decode with post-processing
This commit is contained in:
parent
6d07cf9245
commit
f485b66d54
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||||
|
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -28,6 +29,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
@ -170,6 +172,17 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[str, str]]:
    """Apply GigaSpeech text post-processing to decoded results.

    Args:
      results:
        A list of (ref, hyp) pairs, where each side is a list of words.

    Returns:
      A list of (ref, hyp) pairs where each side has been joined into a
      single string and cleaned by ``asr_text_post_processing``.

    Note: the return annotation was ``List[Tuple[List[str], List[str]]]``
    before, but ``asr_text_post_processing`` returns ``str``, so each
    element of a returned pair is a string, not a word list.
    """
    return [
        (
            asr_text_post_processing(" ".join(ref)),
            asr_text_post_processing(" ".join(hyp)),
        )
        for ref, hyp in results
    ]
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -488,7 +501,7 @@ def decode_dataset(
|
|||||||
def save_results(
|
def save_results(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||||
):
|
):
|
||||||
if params.method == "attention-decoder":
|
if params.method == "attention-decoder":
|
||||||
# Set it to False since there are too many logs.
|
# Set it to False since there are too many logs.
|
||||||
@ -498,6 +511,7 @@ def save_results(
|
|||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
|
results = post_processing(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
if enable_log:
|
if enable_log:
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
115
egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
Executable file
115
egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
Executable file
@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Jiayu Du
|
||||||
|
# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Tokens excluded from WER scoring, following the GigaSpeech convention:
# hesitation fillers, unknown-word tags, punctuation tags, and
# garbage-utterance tags.
conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)
# Frozen-set view of non_scoring_words: O(1) membership tests instead of a
# linear list scan for every word of every utterance. The public list above
# is kept unchanged for backward compatibility.
_NON_SCORING_WORDS = frozenset(non_scoring_words)


def asr_text_post_processing(text: str) -> str:
    """Normalize a transcript for GigaSpeech scoring.

    Args:
      text:
        A space-separated transcript (reference or hypothesis).

    Returns:
      The normalized transcript: uppercased, with hyphens split into
      spaces and all non-scoring words removed.
    """
    # 1. convert to uppercase
    text = text.upper()

    # 2. remove hyphen
    # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
    text = text.replace("-", " ")

    # 3. remove non-scoring words from evaluation
    return " ".join(
        word for word in text.split() if word not in _NON_SCORING_WORDS
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This script evaluates GigaSpeech ASR result via"
        "SCTK's tool sclite"
    )
    parser.add_argument(
        "ref",
        type=str,
        help="sclite's standard transcription(trn) reference file",
    )
    parser.add_argument(
        "hyp",
        type=str,
        help="sclite's standard transcription(trn) hypothesis file",
    )
    parser.add_argument(
        "work_dir",
        type=str,
        help="working dir",
    )
    args = parser.parse_args()

    # makedirs (not mkdir) so a missing parent directory is also created;
    # exist_ok avoids a race between the isdir check and the creation.
    os.makedirs(args.work_dir, exist_ok=True)

    REF = os.path.join(args.work_dir, "REF")
    HYP = os.path.join(args.work_dir, "HYP")
    RESULT = os.path.join(args.work_dir, "RESULT")

    # Rewrite both trn files with post-processed text. A trn line is
    # "<word> <word> ... <uttid>"; the last column is the utterance id.
    for src, dst in [(args.ref, REF), (args.hyp, HYP)]:
        with open(src, "r", encoding="utf8") as fi, open(
            dst, "w", encoding="utf8"
        ) as fo:
            for line in fi:
                line = line.strip()
                if line:
                    cols = line.split()
                    text = asr_text_post_processing(" ".join(cols[0:-1]))
                    uttid_field = cols[-1]
                    print(f"{text} {uttid_field}", file=fo)

    # GigaSpeech's uttid conforms to swb
    # NOTE(review): work_dir is interpolated into a shell command; paths
    # containing spaces or shell metacharacters would break this call.
    # Acceptable for a trusted local scoring script, but consider
    # shlex.quote() if hardening is needed.
    os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}")
|
Loading…
x
Reference in New Issue
Block a user