mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Fix decode.py
This commit is contained in:
parent
22e9d837f8
commit
e9ccc0b073
@ -27,22 +27,9 @@ def get_args():
|
||||
help="""Path to the input text.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action='store_true',
|
||||
help="""Whether to normalize the text.
|
||||
True to normalize the text to upper and remove all punctuation.
|
||||
"""
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def simple_cleanup(text: str) -> str:
|
||||
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||
text = text.translate(table)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_punc_to_upper(text: str) -> str:
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("’", "'")
|
||||
@ -62,10 +49,7 @@ def main():
|
||||
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
|
||||
line = f.readline()
|
||||
while line:
|
||||
if args.normalize:
|
||||
print(remove_punc_to_upper(line))
|
||||
else:
|
||||
print(simple_cleanup(line))
|
||||
print(remove_punc_to_upper(line))
|
||||
line = f.readline()
|
||||
|
||||
|
@ -20,6 +20,11 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def simple_cleanup(text: str) -> str:
|
||||
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||
text = text.translate(table)
|
||||
return text.strip()
|
||||
|
||||
# Assign text of the supervisions and remove unnecessary entries.
|
||||
def main():
|
||||
assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
|
||||
@ -28,7 +33,7 @@ def main():
|
||||
with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
|
||||
for line in fin:
|
||||
cut = json.loads(line)
|
||||
cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
|
||||
cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
|
||||
del cut["supervisions"][0]["custom"]
|
||||
del cut["custom"]
|
||||
fout.write((json.dumps(cut) + "\n").encode())
|
||||
|
@ -221,7 +221,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
if [ ! -f data/texts ]; then
|
||||
gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
|
||||
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
|
||||
| ./local/prepare_text.py --normalize > data/texts
|
||||
| ./local/norm_text.py > data/texts
|
||||
fi
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
@ -244,8 +244,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Train BPE model for unnormalized text"
|
||||
if [ ! -f data/punc_texts ]; then
|
||||
gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
|
||||
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
|
||||
| ./local/prepare_text.py > data/punc_texts
|
||||
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
|
||||
fi
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
new_vacab_size = $(($vocab_size + 256))
|
||||
|
@ -119,11 +119,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from text_normalization import (
|
||||
simple_normalization,
|
||||
decoding_normalization,
|
||||
word_normalization,
|
||||
)
|
||||
from text_normalization import remove_punc_to_upper,
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -141,7 +137,6 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
@ -222,9 +217,6 @@ def get_parser():
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -250,16 +242,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
@ -310,6 +292,14 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--train-with-punctuation",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Set to True, if the model was trained on texts with casing
|
||||
and punctuation."""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
@ -492,8 +482,6 @@ def decode_one_batch(
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
@ -573,6 +561,16 @@ def decode_dataset(
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
this_batch = []
|
||||
if params.post_normalization and params.train_with_punctuation:
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = remove_punc_to_upper(ref_text).split()
|
||||
hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[f"{name}_norm"].extend(this_batch)
|
||||
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
@ -584,17 +582,6 @@ def decode_dataset(
|
||||
return results
|
||||
|
||||
|
||||
def post_processing(
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
) -> List[Tuple[str, List[str], List[str]]]:
|
||||
new_results = []
|
||||
for key, ref, hyp in results:
|
||||
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
|
||||
new_results.append((key, new_ref, new_hyp))
|
||||
return new_results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
@ -605,8 +592,6 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
if test_set_name == "giga-dev" or test_set_name == "giga-test":
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -656,7 +641,6 @@ def main():
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
@ -684,8 +668,6 @@ def main():
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
@ -798,21 +780,9 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
@ -821,37 +791,23 @@ def main():
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
def add_texts(c: Cut):
|
||||
text = c.supervisions[0].text
|
||||
c.supervisions[0].texts = [text]
|
||||
def normalize_text(c: Cut):
|
||||
text = remove_punc_to_upper(c.supervisions[0].text)
|
||||
c.supervisions[0].text = text
|
||||
return c
|
||||
|
||||
test_clean_cuts = libriheavy.test_clean_cuts()
|
||||
test_other_cuts = libriheavy.test_other_cuts()
|
||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||
|
||||
ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts)
|
||||
ls_test_other_cuts = ls_test_other_cuts.map(add_texts)
|
||||
|
||||
giga_dev = libriheavy.gigaspeech_dev_cuts()
|
||||
giga_test = libriheavy.gigaspeech_test_cuts()
|
||||
giga_dev = giga_dev.map(add_texts)
|
||||
giga_test = giga_test.map(add_texts)
|
||||
if not params.train_with_punctuation:
|
||||
test_clean_cuts = test_clean_cuts.map(normalize_text)
|
||||
test_other_cuts = test_other_cuts.map(normalize_text)
|
||||
|
||||
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
|
||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||
|
||||
giga_dev_dl = libriheavy.test_dataloaders(giga_dev)
|
||||
giga_test_dl = libriheavy.test_dataloaders(giga_test)
|
||||
|
||||
# test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
|
||||
# test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
|
||||
|
||||
test_sets = ["giga-test", "giga-dev"]
|
||||
test_dl = [giga_test_dl, giga_dev_dl]
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
@ -869,39 +825,6 @@ def main():
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = {}
|
||||
for k in results_dict:
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
if "librispeech" in test_set:
|
||||
hyp = decoding_normalization(" ".join(hyp)).split()
|
||||
hyp = [word_normalization(w.upper()) for w in hyp]
|
||||
hyp = " ".join(hyp).split()
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
else:
|
||||
hyp = decoding_normalization(" ".join(hyp)).split()
|
||||
hyp = [w.upper() for w in hyp]
|
||||
hyp = " ".join(hyp).split()
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
|
||||
ref = decoding_normalization(" ".join(ref)).split()
|
||||
ref = [w.upper() for w in ref]
|
||||
ref = " ".join(ref).split()
|
||||
ref = [w for w in ref if w != ""]
|
||||
|
||||
new_ans.append((id, ref, hyp))
|
||||
new_res[k] = new_ans
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=new_res,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
52
egs/libriheavy/ASR/zipformer/text_normalization.py
Normal file
52
egs/libriheavy/ASR/zipformer/text_normalization.py
Normal file
@ -0,0 +1,52 @@
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
def remove_punc_to_upper(text: str) -> str:
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("’", "'")
|
||||
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||
s_list = [x.upper() if x in tokens else " " for x in text]
|
||||
s = " ".join("".join(s_list).split()).strip()
|
||||
return s
|
||||
|
||||
|
||||
def word_normalization(word: str) -> str:
|
||||
# 1. Use full word for some abbreviation
|
||||
# 2. Convert digits to english words
|
||||
# 3. Convert ordinal number to english words
|
||||
if word == "MRS":
|
||||
return "MISSUS"
|
||||
if word == "MR":
|
||||
return "MISTER"
|
||||
if word == "ST":
|
||||
return "SAINT"
|
||||
if word == "ECT":
|
||||
return "ET CETERA"
|
||||
|
||||
if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH
|
||||
word = num2words(word[:-2], to="ordinal")
|
||||
word = word.replace("-", " ")
|
||||
|
||||
if word.isnumeric():
|
||||
num = int(word)
|
||||
if num > 1500 and num < 2030:
|
||||
word = num2words(word, to="year")
|
||||
else:
|
||||
word = num2words(word)
|
||||
word = word.replace("-", " ")
|
||||
return word.upper()
|
||||
|
||||
|
||||
def text_normalization(text: str) -> str:
|
||||
text = text.upper()
|
||||
return " ".join([word_normalization(x) for x in text.split()])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
assert simple_cleanup("I like this 《book>") == "I like this <book>"
|
||||
assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
|
||||
assert (
|
||||
text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
|
||||
== "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
|
||||
)
|
@ -77,6 +77,7 @@ from model import AsrModel
|
||||
from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from text_normalization import remove_punc_to_upper
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -1188,16 +1189,7 @@ def run(rank, world_size, args):
|
||||
|
||||
|
||||
def normalize_text(c: Cut):
|
||||
text = c.supervisions[0].text
|
||||
if params.train_with_punctuation:
|
||||
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||
text = text.translate(table)
|
||||
else:
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("’", "'")
|
||||
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||
text_list = [x.upper() if x in tokens else " " for x in text]
|
||||
text = " ".join("".join(text_list).split()).strip()
|
||||
text = remove_punc_to_upper(c.supervisions[0].text)
|
||||
c.supervisions[0].text = text
|
||||
return c
|
||||
|
||||
@ -1245,7 +1237,9 @@ def run(rank, world_size, args):
|
||||
train_cuts += libriheavy.train_medium_cuts()
|
||||
if params.subset == "L":
|
||||
train_cuts += libriheavy.train_large_cuts()
|
||||
train_cuts = train_cuts.map(normalize_text)
|
||||
|
||||
if not params.train_with_punctuation:
|
||||
train_cuts = train_cuts.map(normalize_text)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
@ -1261,7 +1255,9 @@ def run(rank, world_size, args):
|
||||
)
|
||||
|
||||
valid_cuts = libriheavy.dev_cuts()
|
||||
valid_cuts = valid_cuts.map(normalize_text)
|
||||
|
||||
if not params.train_with_punctuation:
|
||||
valid_cuts = valid_cuts.map(normalize_text)
|
||||
|
||||
valid_dl = libriheavy.valid_dataloaders(valid_cuts)
|
||||
|
||||
|
@ -17,6 +17,7 @@ six
|
||||
git+https://github.com/lhotse-speech/lhotse
|
||||
kaldilm==1.11
|
||||
kaldialign==0.7.1
|
||||
num2words
|
||||
sentencepiece==0.1.96
|
||||
tensorboard==2.8.0
|
||||
typeguard==2.13.3
|
||||
|
@ -1,6 +1,7 @@
|
||||
kaldifst
|
||||
kaldilm
|
||||
kaldialign
|
||||
num2words
|
||||
sentencepiece>=0.1.96
|
||||
tensorboard
|
||||
typeguard
|
||||
|
Loading…
x
Reference in New Issue
Block a user