Fix decode.py

commit e9ccc0b073 (parent 22e9d837f8)
@@ -27,22 +27,9 @@ def get_args():
         help="""Path to the input text.
         """,
     )
-    parser.add_argument(
-        "--normalize",
-        action='store_true',
-        help="""Whether to normalize the text.
-        True to normalize the text to upper and remove all punctuation.
-        """
-    )
     return parser.parse_args()
-
-
-def simple_cleanup(text: str) -> str:
-    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-    text = text.translate(table)
-    return text.strip()
 
 
 def remove_punc_to_upper(text: str) -> str:
     text = text.replace("‘", "'")
     text = text.replace("’", "'")

@@ -62,10 +49,7 @@ def main():
     sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
     line = f.readline()
     while line:
-        if args.normalize:
-            print(remove_punc_to_upper(line))
-        else:
-            print(simple_cleanup(line))
+        print(remove_punc_to_upper(line))
         line = f.readline()
 
 
@@ -20,6 +20,11 @@ import json
 import sys
 from pathlib import Path
 
+def simple_cleanup(text: str) -> str:
+    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+    text = text.translate(table)
+    return text.strip()
+
 
 # Assign text of the supervisions and remove unnecessary entries.
 def main():
     assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"

@@ -28,7 +33,7 @@ def main():
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
-            cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
+            cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
             del cut["supervisions"][0]["custom"]
             del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
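For reference, a minimal sketch (not part of the commit) of what the simple_cleanup helper added above does to a supervision text; the sample string is borrowed from the self-test in text_normalization.py further down.

# Minimal sketch of simple_cleanup: full-width punctuation is mapped to its
# ASCII counterpart and surrounding whitespace is trimmed.
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")


def simple_cleanup(text: str) -> str:
    return text.translate(table).strip()


print(simple_cleanup("I like this 《book>"))  # -> I like this <book>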
@@ -221,7 +221,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
       | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py --normalize > data/texts
+      | ./local/norm_text.py > data/texts
   fi
 
   for vocab_size in ${vocab_sizes[@]}; do

@@ -244,8 +244,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "Stage 10: Train BPE model for unnormalized text"
   if [ ! -f data/punc_texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
   fi
   for vocab_size in ${vocab_sizes[@]}; do
     new_vacab_size = $(($vocab_size + 256))
@@ -119,11 +119,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import (
-    simple_normalization,
-    decoding_normalization,
-    word_normalization,
-)
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (

@@ -141,7 +137,6 @@ from icefall.utils import (
     str2bool,
     write_error_stats,
 )
-from gigaspeech_scoring import asr_text_post_processing
 
 LOG_EPS = math.log(1e-10)
 

@@ -222,9 +217,6 @@ def get_parser():
         - fast_beam_search
         - fast_beam_search_nbest
         - fast_beam_search_nbest_oracle
-        - fast_beam_search_nbest_LG
-        If you use fast_beam_search_nbest_LG, you have to specify
-        `--lang-dir`, which should contain `LG.pt`.
         """,
     )
 

@@ -250,16 +242,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=0.01,
-        help="""
-        Used only when --decoding_method is fast_beam_search_nbest_LG.
-        It specifies the scale for n-gram LM scores.
-        """,
-    )
-
     parser.add_argument(
         "--max-contexts",
         type=int,

@@ -310,6 +292,14 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--train-with-punctuation",
+        type=str2bool,
+        default=False,
+        help="""Set to True, if the model was trained on texts with casing
+        and punctuation.""",
+    )
+
     parser.add_argument(
         "--post-normalization",
         type=str2bool,

@@ -492,8 +482,6 @@ def decode_one_batch(
         if "nbest" in params.decoding_method:
             key += f"_num_paths_{params.num_paths}_"
             key += f"nbest_scale_{params.nbest_scale}"
-            if "LG" in params.decoding_method:
-                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
-
         return {key: hyps}
     else:

@@ -573,6 +561,16 @@ def decode_dataset(
 
             results[name].extend(this_batch)
 
+            this_batch = []
+            if params.post_normalization and params.train_with_punctuation:
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = remove_punc_to_upper(ref_text).split()
+                    hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+
+                results[f"{name}_norm"].extend(this_batch)
+
+
         num_cuts += len(texts)
 
         if batch_idx % log_interval == 0:

@@ -584,17 +582,6 @@ def decode_dataset(
     return results
 
 
-def post_processing(
-    results: List[Tuple[str, List[str], List[str]]],
-) -> List[Tuple[str, List[str], List[str]]]:
-    new_results = []
-    for key, ref, hyp in results:
-        new_ref = asr_text_post_processing(" ".join(ref)).split()
-        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
-        new_results.append((key, new_ref, new_hyp))
-    return new_results
-
-
 def save_results(
     params: AttributeDict,
     test_set_name: str,

@@ -605,8 +592,6 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        if test_set_name == "giga-dev" or test_set_name == "giga-test":
-            results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -656,7 +641,6 @@ def main():
         "beam_search",
         "fast_beam_search",
         "fast_beam_search_nbest",
-        "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )

@@ -684,8 +668,6 @@ def main():
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
-           if "LG" in params.decoding_method:
-               params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:

@@ -798,21 +780,9 @@ def main():
     model.eval()
 
     if "fast_beam_search" in params.decoding_method:
-        if params.decoding_method == "fast_beam_search_nbest_LG":
-            lexicon = Lexicon(params.lang_dir)
-            word_table = lexicon.word_table
-            lg_filename = params.lang_dir / "LG.pt"
-            logging.info(f"Loading {lg_filename}")
-            decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
-            )
-            decoding_graph.scores *= params.ngram_lm_scale
-        else:
-            word_table = None
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
-        word_table = None
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -821,37 +791,23 @@ def main():
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)
 
-    def add_texts(c: Cut):
-        text = c.supervisions[0].text
-        c.supervisions[0].texts = [text]
+    def normalize_text(c: Cut):
+        text = remove_punc_to_upper(c.supervisions[0].text)
+        c.supervisions[0].text = text
         return c
 
     test_clean_cuts = libriheavy.test_clean_cuts()
     test_other_cuts = libriheavy.test_other_cuts()
-    ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
-    ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
 
-    ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts)
-    ls_test_other_cuts = ls_test_other_cuts.map(add_texts)
-
-    giga_dev = libriheavy.gigaspeech_dev_cuts()
-    giga_test = libriheavy.gigaspeech_test_cuts()
-    giga_dev = giga_dev.map(add_texts)
-    giga_test = giga_test.map(add_texts)
+    if not params.train_with_punctuation:
+        test_clean_cuts = test_clean_cuts.map(normalize_text)
+        test_other_cuts = test_other_cuts.map(normalize_text)
 
     test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
     test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
-    ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
-    ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
 
-    giga_dev_dl = libriheavy.test_dataloaders(giga_dev)
-    giga_test_dl = libriheavy.test_dataloaders(giga_test)
-
-    # test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
-    # test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
-
-    test_sets = ["giga-test", "giga-dev"]
-    test_dl = [giga_test_dl, giga_dev_dl]
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

@@ -869,39 +825,6 @@ def main():
             results_dict=results_dict,
         )
 
-        if params.post_normalization:
-            params.suffix += "-post-normalization"
-
-            new_res = {}
-            for k in results_dict:
-                new_ans = []
-                for item in results_dict[k]:
-                    id, ref, hyp = item
-                    if "librispeech" in test_set:
-                        hyp = decoding_normalization(" ".join(hyp)).split()
-                        hyp = [word_normalization(w.upper()) for w in hyp]
-                        hyp = " ".join(hyp).split()
-                        hyp = [w for w in hyp if w != ""]
-                    else:
-                        hyp = decoding_normalization(" ".join(hyp)).split()
-                        hyp = [w.upper() for w in hyp]
-                        hyp = " ".join(hyp).split()
-                        hyp = [w for w in hyp if w != ""]
-
-                    ref = decoding_normalization(" ".join(ref)).split()
-                    ref = [w.upper() for w in ref]
-                    ref = " ".join(ref).split()
-                    ref = [w for w in ref if w != ""]
-
-                    new_ans.append((id, ref, hyp))
-                new_res[k] = new_ans
-
-            save_results(
-                params=params,
-                test_set_name=test_set,
-                results_dict=new_res,
-            )
-
     logging.info("Done!")
 
 
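For illustration, a hedged sketch (the cut id and texts below are invented) of what the new post-normalization branch in decode_dataset amounts to for one utterance when both --post-normalization and --train-with-punctuation are set: reference and hypothesis are passed through remove_punc_to_upper before scoring, so casing and punctuation differences are not counted as errors.

# Sketch of the post-normalization scoring path; the example data are hypothetical.
from text_normalization import remove_punc_to_upper

cut_id = "utt-0001"                    # hypothetical cut id
ref_text = "Hello, Mrs. Smith!"        # reference with casing and punctuation
hyp_words = ["Hello", "Mrs", "Smith"]  # decoded hypothesis words

ref_words = remove_punc_to_upper(ref_text).split()
hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
print((cut_id, ref_words, hyp_words))
# ('utt-0001', ['HELLO', 'MRS', 'SMITH'], ['HELLO', 'MRS', 'SMITH'])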
egs/libriheavy/ASR/zipformer/text_normalization.py (new file, 52 lines)

@@ -0,0 +1,52 @@
+from num2words import num2words
+
+
+def remove_punc_to_upper(text: str) -> str:
+    text = text.replace("‘", "'")
+    text = text.replace("’", "'")
+    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+    s_list = [x.upper() if x in tokens else " " for x in text]
+    s = " ".join("".join(s_list).split()).strip()
+    return s
+
+
+def word_normalization(word: str) -> str:
+    # 1. Use full word for some abbreviation
+    # 2. Convert digits to english words
+    # 3. Convert ordinal number to english words
+    if word == "MRS":
+        return "MISSUS"
+    if word == "MR":
+        return "MISTER"
+    if word == "ST":
+        return "SAINT"
+    if word == "ECT":
+        return "ET CETERA"
+
+    if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric():  # e.g 9TH, 6TH
+        word = num2words(word[:-2], to="ordinal")
+        word = word.replace("-", " ")
+
+    if word.isnumeric():
+        num = int(word)
+        if num > 1500 and num < 2030:
+            word = num2words(word, to="year")
+        else:
+            word = num2words(word)
+        word = word.replace("-", " ")
+    return word.upper()
+
+
+def text_normalization(text: str) -> str:
+    text = text.upper()
+    return " ".join([word_normalization(x) for x in text.split()])
+
+
+if __name__ == "__main__":
+
+    assert simple_cleanup("I like this 《book>") == "I like this <book>"
+    assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
+    assert (
+        text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
+        == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
+    )
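A minimal usage sketch (not part of the commit) of the helpers defined in text_normalization.py; the expected outputs mirror the module's own self-test under if __name__ == "__main__".

# Usage sketch of the new text_normalization helpers.
from text_normalization import (
    remove_punc_to_upper,
    text_normalization,
    word_normalization,
)

print(remove_punc_to_upper("I like this 《book>"))
# -> I LIKE THIS BOOK
print(word_normalization("21ST"))
# -> TWENTY FIRST  (ordinal expanded via num2words)
print(text_normalization("Hello Mrs st 21st world 3rd she 99th MR"))
# -> HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER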
@@ -77,6 +77,7 @@ from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
+from text_normalization import remove_punc_to_upper
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -1188,16 +1189,7 @@ def run(rank, world_size, args):
 
 
     def normalize_text(c: Cut):
-        text = c.supervisions[0].text
-        if params.train_with_punctuation:
-            table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-            text = text.translate(table)
-        else:
-            text = text.replace("‘", "'")
-            text = text.replace("’", "'")
-            tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-            text_list = [x.upper() if x in tokens else " " for x in text]
-            text = " ".join("".join(text_list).split()).strip()
+        text = remove_punc_to_upper(c.supervisions[0].text)
         c.supervisions[0].text = text
         return c
 

@@ -1245,7 +1237,9 @@ def run(rank, world_size, args):
         train_cuts += libriheavy.train_medium_cuts()
     if params.subset == "L":
         train_cuts += libriheavy.train_large_cuts()
-    train_cuts = train_cuts.map(normalize_text)
+
+    if not params.train_with_punctuation:
+        train_cuts = train_cuts.map(normalize_text)
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 

@@ -1261,7 +1255,9 @@ def run(rank, world_size, args):
     )
 
     valid_cuts = libriheavy.dev_cuts()
-    valid_cuts = valid_cuts.map(normalize_text)
+
+    if not params.train_with_punctuation:
+        valid_cuts = valid_cuts.map(normalize_text)
 
     valid_dl = libriheavy.valid_dataloaders(valid_cuts)
 
@@ -17,6 +17,7 @@ six
 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11
 kaldialign==0.7.1
+num2words
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3
@@ -1,6 +1,7 @@
 kaldifst
 kaldilm
 kaldialign
+num2words
 sentencepiece>=0.1.96
 tensorboard
 typeguard