Fix decode.py

pkufool 2023-09-22 15:33:29 +08:00
parent 22e9d837f8
commit e9ccc0b073
8 changed files with 99 additions and 138 deletions

View File

@@ -27,22 +27,9 @@ def get_args():
         help="""Path to the input text.
         """,
     )
-    parser.add_argument(
-        "--normalize",
-        action='store_true',
-        help="""Whether to normalize the text.
-        True to normalize the text to upper and remove all punctuation.
-        """
-    )
     return parser.parse_args()
-def simple_cleanup(text: str) -> str:
-    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-    text = text.translate(table)
-    return text.strip()
 def remove_punc_to_upper(text: str) -> str:
     text = text.replace("’", "'")
     text = text.replace("‘", "'")
@@ -62,10 +49,7 @@ def main():
     sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
     line = f.readline()
     while line:
-        if args.normalize:
-            print(remove_punc_to_upper(line))
-        else:
-            print(simple_cleanup(line))
+        print(remove_punc_to_upper(line))
         line = f.readline()

View File

@@ -20,6 +20,11 @@ import json
 import sys
 from pathlib import Path
+def simple_cleanup(text: str) -> str:
+    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+    text = text.translate(table)
+    return text.strip()
 # Assign text of the supervisions and remove unnecessary entries.
 def main():
     assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
@@ -28,7 +33,7 @@ def main():
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
-            cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
+            cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
             del cut["supervisions"][0]["custom"]
             del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
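
A self-contained sketch of the simple_cleanup helper added above, with sample input and output; the translation table is copied from the diff, and the expected outputs are what str.translate should give for it:

def simple_cleanup(text: str) -> str:
    # Map curly quotes and CJK punctuation to ASCII equivalents, turn "-" into a space,
    # then trim surrounding whitespace.
    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
    return text.translate(table).strip()

print(simple_cleanup("I like this 《book》。"))  # -> I like this <book>.
print(simple_cleanup("well-known"))            # -> well known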

View File

@@ -221,7 +221,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
       | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py --normalize > data/texts
+      | ./local/norm_text.py > data/texts
   fi
   for vocab_size in ${vocab_sizes[@]}; do
@@ -244,8 +244,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "Stage 10: Train BPE model for unnormalized text"
   if [ ! -f data/punc_texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
   fi
   for vocab_size in ${vocab_sizes[@]}; do
     new_vacab_size = $(($vocab_size + 256))

View File

@@ -119,11 +119,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import (
-    simple_normalization,
-    decoding_normalization,
-    word_normalization,
-)
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params
 from icefall.checkpoint import (
@@ -141,7 +137,6 @@ from icefall.utils import (
     str2bool,
     write_error_stats,
 )
-from gigaspeech_scoring import asr_text_post_processing
 LOG_EPS = math.log(1e-10)
@@ -222,9 +217,6 @@ def get_parser():
           - fast_beam_search
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
-          - fast_beam_search_nbest_LG
-        If you use fast_beam_search_nbest_LG, you have to specify
-        `--lang-dir`, which should contain `LG.pt`.
         """,
     )
@@ -250,16 +242,6 @@ def get_parser():
         """,
     )
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=0.01,
-        help="""
-        Used only when --decoding_method is fast_beam_search_nbest_LG.
-        It specifies the scale for n-gram LM scores.
-        """,
-    )
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -310,6 +292,14 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
+    parser.add_argument(
+        "--train-with-punctuation",
+        type=str2bool,
+        default=False,
+        help="""Set to True, if the model was trained on texts with casing
+        and punctuation."""
+    )
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
@@ -492,8 +482,6 @@ def decode_one_batch(
         if "nbest" in params.decoding_method:
             key += f"_num_paths_{params.num_paths}_"
             key += f"nbest_scale_{params.nbest_scale}"
-        if "LG" in params.decoding_method:
-            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
         return {key: hyps}
     else:
@@ -573,6 +561,16 @@ def decode_dataset(
             results[name].extend(this_batch)
+            this_batch = []
+            if params.post_normalization and params.train_with_punctuation:
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = remove_punc_to_upper(ref_text).split()
+                    hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+                results[f"{name}_norm"].extend(this_batch)
         num_cuts += len(texts)
         if batch_idx % log_interval == 0:
@@ -584,17 +582,6 @@ def decode_dataset(
     return results
-def post_processing(
-    results: List[Tuple[str, List[str], List[str]]],
-) -> List[Tuple[str, List[str], List[str]]]:
-    new_results = []
-    for key, ref, hyp in results:
-        new_ref = asr_text_post_processing(" ".join(ref)).split()
-        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
-        new_results.append((key, new_ref, new_hyp))
-    return new_results
 def save_results(
     params: AttributeDict,
     test_set_name: str,
@@ -605,8 +592,6 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        if test_set_name == "giga-dev" or test_set_name == "giga-test":
-            results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -656,7 +641,6 @@ def main():
         "beam_search",
         "fast_beam_search",
         "fast_beam_search_nbest",
-        "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
@@ -684,8 +668,6 @@ def main():
         if "nbest" in params.decoding_method:
             params.suffix += f"-nbest-scale-{params.nbest_scale}"
             params.suffix += f"-num-paths-{params.num_paths}"
-            if "LG" in params.decoding_method:
-                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
@@ -798,21 +780,9 @@ def main():
     model.eval()
     if "fast_beam_search" in params.decoding_method:
-        if params.decoding_method == "fast_beam_search_nbest_LG":
-            lexicon = Lexicon(params.lang_dir)
-            word_table = lexicon.word_table
-            lg_filename = params.lang_dir / "LG.pt"
-            logging.info(f"Loading {lg_filename}")
-            decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
-            )
-            decoding_graph.scores *= params.ngram_lm_scale
-        else:
-            word_table = None
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
-        word_table = None
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@ -821,37 +791,23 @@ def main():
args.return_cuts = True args.return_cuts = True
libriheavy = LibriHeavyAsrDataModule(args) libriheavy = LibriHeavyAsrDataModule(args)
def add_texts(c: Cut): def normalize_text(c: Cut):
text = c.supervisions[0].text text = remove_punc_to_upper(c.supervisions[0].text)
c.supervisions[0].texts = [text] c.supervisions[0].text = text
return c return c
test_clean_cuts = libriheavy.test_clean_cuts() test_clean_cuts = libriheavy.test_clean_cuts()
test_other_cuts = libriheavy.test_other_cuts() test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts) if not params.train_with_punctuation:
ls_test_other_cuts = ls_test_other_cuts.map(add_texts) test_clean_cuts = test_clean_cuts.map(normalize_text)
test_other_cuts = test_other_cuts.map(normalize_text)
giga_dev = libriheavy.gigaspeech_dev_cuts()
giga_test = libriheavy.gigaspeech_test_cuts()
giga_dev = giga_dev.map(add_texts)
giga_test = giga_test.map(add_texts)
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts) test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
test_other_dl = libriheavy.test_dataloaders(test_other_cuts) test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
giga_dev_dl = libriheavy.test_dataloaders(giga_dev) test_sets = ["test-clean", "test-other"]
giga_test_dl = libriheavy.test_dataloaders(giga_test) test_dl = [test_clean_dl, test_other_dl]
# test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
# test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
test_sets = ["giga-test", "giga-dev"]
test_dl = [giga_test_dl, giga_dev_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
@@ -869,39 +825,6 @@ def main():
             results_dict=results_dict,
         )
-        if params.post_normalization:
-            params.suffix += "-post-normalization"
-            new_res = {}
-            for k in results_dict:
-                new_ans = []
-                for item in results_dict[k]:
-                    id, ref, hyp = item
-                    if "librispeech" in test_set:
-                        hyp = decoding_normalization(" ".join(hyp)).split()
-                        hyp = [word_normalization(w.upper()) for w in hyp]
-                        hyp = " ".join(hyp).split()
-                        hyp = [w for w in hyp if w != ""]
-                    else:
-                        hyp = decoding_normalization(" ".join(hyp)).split()
-                        hyp = [w.upper() for w in hyp]
-                        hyp = " ".join(hyp).split()
-                        hyp = [w for w in hyp if w != ""]
-                    ref = decoding_normalization(" ".join(ref)).split()
-                    ref = [w.upper() for w in ref]
-                    ref = " ".join(ref).split()
-                    ref = [w for w in ref if w != ""]
-                    new_ans.append((id, ref, hyp))
-                new_res[k] = new_ans
-            save_results(
-                params=params,
-                test_set_name=test_set,
-                results_dict=new_res,
-            )
     logging.info("Done!")

View File

@@ -0,0 +1,52 @@
+from num2words import num2words
+
+
+def remove_punc_to_upper(text: str) -> str:
+    text = text.replace("’", "'")
+    text = text.replace("‘", "'")
+    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+    s_list = [x.upper() if x in tokens else " " for x in text]
+    s = " ".join("".join(s_list).split()).strip()
+    return s
+
+
+def word_normalization(word: str) -> str:
+    # 1. Use full word for some abbreviation
+    # 2. Convert digits to english words
+    # 3. Convert ordinal number to english words
+    if word == "MRS":
+        return "MISSUS"
+    if word == "MR":
+        return "MISTER"
+    if word == "ST":
+        return "SAINT"
+    if word == "ECT":
+        return "ET CETERA"
+
+    if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric():  # e.g 9TH, 6TH
+        word = num2words(word[:-2], to="ordinal")
+        word = word.replace("-", " ")
+
+    if word.isnumeric():
+        num = int(word)
+        if num > 1500 and num < 2030:
+            word = num2words(word, to="year")
+        else:
+            word = num2words(word)
+        word = word.replace("-", " ")
+    return word.upper()
+
+
+def text_normalization(text: str) -> str:
+    text = text.upper()
+    return " ".join([word_normalization(x) for x in text.split()])
+
+
+if __name__ == "__main__":
+    assert simple_cleanup("I like this 《book>") == "I like this <book>"
+    assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
+    assert (
+        text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
+        == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
+    )
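
The num2words package (added to the requirements files below) does the digit-to-words conversion used by word_normalization. A small illustration of the call forms that appear above, with expected outputs shown as comments; word_normalization afterwards replaces "-" with a space and upper-cases the result:

from num2words import num2words

print(num2words(21, to="ordinal"))  # "twenty-first", used for tokens like 21ST
print(num2words(1815, to="year"))   # "eighteen fifteen", used for numbers strictly between 1500 and 2030
print(num2words(342))               # "three hundred and forty-two", default cardinal form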

View File

@@ -77,6 +77,7 @@ from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
+from text_normalization import remove_punc_to_upper
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -1188,16 +1189,7 @@ def run(rank, world_size, args):
     def normalize_text(c: Cut):
-        text = c.supervisions[0].text
-        if params.train_with_punctuation:
-            table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-            text = text.translate(table)
-        else:
-            text = text.replace("’", "'")
-            text = text.replace("‘", "'")
-            tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-            text_list = [x.upper() if x in tokens else " " for x in text]
-            text = " ".join("".join(text_list).split()).strip()
+        text = remove_punc_to_upper(c.supervisions[0].text)
         c.supervisions[0].text = text
         return c
@@ -1245,7 +1237,9 @@ def run(rank, world_size, args):
         train_cuts += libriheavy.train_medium_cuts()
     if params.subset == "L":
         train_cuts += libriheavy.train_large_cuts()
-    train_cuts = train_cuts.map(normalize_text)
+    if not params.train_with_punctuation:
+        train_cuts = train_cuts.map(normalize_text)
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1261,7 +1255,9 @@ def run(rank, world_size, args):
     )
     valid_cuts = libriheavy.dev_cuts()
-    valid_cuts = valid_cuts.map(normalize_text)
+    if not params.train_with_punctuation:
+        valid_cuts = valid_cuts.map(normalize_text)
     valid_dl = libriheavy.valid_dataloaders(valid_cuts)

View File

@@ -17,6 +17,7 @@ six
 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11
 kaldialign==0.7.1
+num2words
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3

View File

@@ -1,6 +1,7 @@
 kaldifst
 kaldilm
 kaldialign
+num2words
 sentencepiece>=0.1.96
 tensorboard
 typeguard
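
For an existing environment that predates this commit, the new dependency can also be installed directly; the package name is the same one added to the requirements files above:

pip install num2words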