Fix decode.py

pkufool 2023-09-22 15:33:29 +08:00
parent 22e9d837f8
commit e9ccc0b073
8 changed files with 99 additions and 138 deletions


@@ -27,22 +27,9 @@ def get_args():
help="""Path to the input text.
""",
)
parser.add_argument(
"--normalize",
action='store_true',
help="""Whether to normalize the text.
True to normalize the text to upper and remove all punctuation.
"""
)
return parser.parse_args()
def simple_cleanup(text: str) -> str:
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
text = text.translate(table)
return text.strip()
def remove_punc_to_upper(text: str) -> str:
text = text.replace("‘", "'")
text = text.replace("’", "'")
@@ -62,10 +49,7 @@ def main():
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
line = f.readline()
while line:
if args.normalize:
print(remove_punc_to_upper(line))
else:
print(simple_cleanup(line))
print(remove_punc_to_upper(line))
line = f.readline()
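
(Not part of the commit.) A minimal sketch of the normalization every input line now receives; the helper below is hypothetical and simply inlines the remove_punc_to_upper logic shown above:

# Illustrative only: mirrors remove_punc_to_upper, which the script now applies unconditionally.
def demo_remove_punc_to_upper(text: str) -> str:
    text = text.replace("‘", "'").replace("’", "'")
    keep = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    return " ".join("".join(c.upper() if c in keep else " " for c in text).split())

print(demo_remove_punc_to_upper("It's a “test”, isn't it?"))  # IT'S A TEST ISN'T IT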


@@ -20,6 +20,11 @@ import json
import sys
from pathlib import Path
def simple_cleanup(text: str) -> str:
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
text = text.translate(table)
return text.strip()
# Assign text of the supervisions and remove unnecessary entries.
def main():
assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
@@ -28,7 +33,7 @@ def main():
with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
for line in fin:
cut = json.loads(line)
cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
del cut["supervisions"][0]["custom"]
del cut["custom"]
fout.write((json.dumps(cut) + "\n").encode())
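
(Not part of the commit.) As a quick illustration, the simple_cleanup table that now runs over every supervision text maps full-width punctuation onto ASCII equivalents:

# Illustrative only: the effect of the simple_cleanup translation table on a manifest text field.
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
print("I like this 《book》".translate(table).strip())  # I like this <book>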


@@ -221,7 +221,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
if [ ! -f data/texts ]; then
gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
| ./local/prepare_text.py --normalize > data/texts
| ./local/norm_text.py > data/texts
fi
for vocab_size in ${vocab_sizes[@]}; do
@@ -244,8 +244,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Train BPE model for unnormalized text"
if [ ! -f data/punc_texts ]; then
gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
| ./local/prepare_text.py > data/punc_texts
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
fi
for vocab_size in ${vocab_sizes[@]}; do
new_vacab_size=$(($vocab_size + 256))


@@ -119,11 +119,7 @@ from beam_search import (
modified_beam_search,
)
from lhotse.cut import Cut
from text_normalization import (
simple_normalization,
decoding_normalization,
word_normalization,
)
from text_normalization import remove_punc_to_upper
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
@@ -141,7 +137,6 @@ from icefall.utils import (
str2bool,
write_error_stats,
)
from gigaspeech_scoring import asr_text_post_processing
LOG_EPS = math.log(1e-10)
@@ -222,9 +217,6 @@ def get_parser():
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
@@ -250,16 +242,6 @@ def get_parser():
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
@@ -310,6 +292,14 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--train-with-punctuation",
type=str2bool,
default=False,
help="""Set to True, if the model was trained on texts with casing
and punctuation."""
)
parser.add_argument(
"--post-normalization",
type=str2bool,
@@ -492,8 +482,6 @@ def decode_one_batch(
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
@@ -573,6 +561,16 @@ def decode_dataset(
results[name].extend(this_batch)
this_batch = []
if params.post_normalization and params.train_with_punctuation:
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = remove_punc_to_upper(ref_text).split()
hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
this_batch.append((cut_id, ref_words, hyp_words))
results[f"{name}_norm"].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
@@ -584,17 +582,6 @@ def decode_dataset(
return results
def post_processing(
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((key, new_ref, new_hyp))
return new_results
def save_results(
params: AttributeDict,
test_set_name: str,
@@ -605,8 +592,6 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
if test_set_name == "giga-dev" or test_set_name == "giga-test":
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@@ -656,7 +641,6 @@ def main():
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
@@ -684,8 +668,6 @@ def main():
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
@@ -798,21 +780,9 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -821,37 +791,23 @@ def main():
args.return_cuts = True
libriheavy = LibriHeavyAsrDataModule(args)
def add_texts(c: Cut):
text = c.supervisions[0].text
c.supervisions[0].texts = [text]
def normalize_text(c: Cut):
text = remove_punc_to_upper(c.supervisions[0].text)
c.supervisions[0].text = text
return c
test_clean_cuts = libriheavy.test_clean_cuts()
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts)
ls_test_other_cuts = ls_test_other_cuts.map(add_texts)
giga_dev = libriheavy.gigaspeech_dev_cuts()
giga_test = libriheavy.gigaspeech_test_cuts()
giga_dev = giga_dev.map(add_texts)
giga_test = giga_test.map(add_texts)
if not params.train_with_punctuation:
test_clean_cuts = test_clean_cuts.map(normalize_text)
test_other_cuts = test_other_cuts.map(normalize_text)
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
giga_dev_dl = libriheavy.test_dataloaders(giga_dev)
giga_test_dl = libriheavy.test_dataloaders(giga_test)
# test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
# test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
test_sets = ["giga-test", "giga-dev"]
test_dl = [giga_test_dl, giga_dev_dl]
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
@@ -869,39 +825,6 @@ def main():
results_dict=results_dict,
)
if params.post_normalization:
params.suffix += "-post-normalization"
new_res = {}
for k in results_dict:
new_ans = []
for item in results_dict[k]:
id, ref, hyp = item
if "librispeech" in test_set:
hyp = decoding_normalization(" ".join(hyp)).split()
hyp = [word_normalization(w.upper()) for w in hyp]
hyp = " ".join(hyp).split()
hyp = [w for w in hyp if w != ""]
else:
hyp = decoding_normalization(" ".join(hyp)).split()
hyp = [w.upper() for w in hyp]
hyp = " ".join(hyp).split()
hyp = [w for w in hyp if w != ""]
ref = decoding_normalization(" ".join(ref)).split()
ref = [w.upper() for w in ref]
ref = " ".join(ref).split()
ref = [w for w in ref if w != ""]
new_ans.append((id, ref, hyp))
new_res[k] = new_ans
save_results(
params=params,
test_set_name=test_set,
results_dict=new_res,
)
logging.info("Done!")


@@ -0,0 +1,52 @@
from num2words import num2words
def remove_punc_to_upper(text: str) -> str:
text = text.replace("‘", "'")
text = text.replace("’", "'")
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
s_list = [x.upper() if x in tokens else " " for x in text]
s = " ".join("".join(s_list).split()).strip()
return s
def word_normalization(word: str) -> str:
# 1. Use full word for some abbreviation
# 2. Convert digits to english words
# 3. Convert ordinal number to english words
if word == "MRS":
return "MISSUS"
if word == "MR":
return "MISTER"
if word == "ST":
return "SAINT"
if word == "ECT":
return "ET CETERA"
if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g. 9TH, 6TH
word = num2words(word[:-2], to="ordinal")
word = word.replace("-", " ")
if word.isnumeric():
num = int(word)
if num > 1500 and num < 2030:
word = num2words(word, to="year")
else:
word = num2words(word)
word = word.replace("-", " ")
return word.upper()
def text_normalization(text: str) -> str:
text = text.upper()
return " ".join([word_normalization(x) for x in text.split()])
if __name__ == "__main__":
assert simple_cleanup("I like this 《book>") == "I like this <book>"
assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
assert (
text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
== "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
)


@@ -77,6 +77,7 @@ from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from text_normalization import remove_punc_to_upper
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -1188,16 +1189,7 @@ def run(rank, world_size, args):
def normalize_text(c: Cut):
text = c.supervisions[0].text
if params.train_with_punctuation:
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
text = text.translate(table)
else:
text = text.replace("‘", "'")
text = text.replace("’", "'")
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
text_list = [x.upper() if x in tokens else " " for x in text]
text = " ".join("".join(text_list).split()).strip()
text = remove_punc_to_upper(c.supervisions[0].text)
c.supervisions[0].text = text
return c
@@ -1245,7 +1237,9 @@ def run(rank, world_size, args):
train_cuts += libriheavy.train_medium_cuts()
if params.subset == "L":
train_cuts += libriheavy.train_large_cuts()
train_cuts = train_cuts.map(normalize_text)
if not params.train_with_punctuation:
train_cuts = train_cuts.map(normalize_text)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1261,7 +1255,9 @@ def run(rank, world_size, args):
)
valid_cuts = libriheavy.dev_cuts()
valid_cuts = valid_cuts.map(normalize_text)
if not params.train_with_punctuation:
valid_cuts = valid_cuts.map(normalize_text)
valid_dl = libriheavy.valid_dataloaders(valid_cuts)


@@ -17,6 +17,7 @@ six
git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11
kaldialign==0.7.1
num2words
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3


@@ -1,6 +1,7 @@ kaldifst
kaldifst
kaldilm
kaldialign
num2words
sentencepiece>=0.1.96
tensorboard
typeguard