merge with master

This commit is contained in:
pkufool 2024-06-21 18:05:52 +08:00
parent 3059eb4511
commit 19f88482be
5 changed files with 73 additions and 43 deletions

View File

@ -556,18 +556,14 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
@ -577,9 +573,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:

View File

@ -239,12 +239,22 @@ def add_model_arguments(parser: argparse.ArgumentParser):
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
parser.add_argument(
"--use-transducer",
type=str2bool,
default=True,
help="If True, use Transducer head.",
)
parser.add_argument(
"--use-ctc",
type=str2bool,
default=False,
help="If True, use CTC head.",
)
def add_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--world-size",
type=int,
@ -302,16 +312,6 @@ def get_parser():
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--base-lr", type=float, default=0.045, help="The base learning rate."
)
@ -379,6 +379,13 @@ def get_parser():
with this parameter before adding to the final loss.""",
)
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--seed",
type=int,
@ -444,6 +451,24 @@ def get_parser():
help="Whether to use half precision training.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
add_training_arguments(parser)
add_model_arguments(parser)
return parser
@ -1154,26 +1179,6 @@ def run(rank, world_size, args):
# )
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -28,5 +28,6 @@ multi_quantization
onnx
onnxmltools
onnxruntime
pypinyin
kaldifst
kaldi-decoder

View File

@ -3,6 +3,7 @@ kaldilm
kaldialign
num2words
kaldi-decoder
pypinyin
sentencepiece>=0.1.96
pypinyin==0.50.0
tensorboard

View File

@ -28,6 +28,7 @@ from icefall.utils import (
encode_supervisions,
get_texts,
make_pad_mask,
text_to_pinyin,
)
@ -163,3 +164,31 @@ def test_add_eos():
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
)
assert str(ragged_eos) == str(expected)
def test_text_to_pinyin():
txt = "想吃KFC"
r = text_to_pinyin(txt, mode="full_with_tone")
assert " ".join(r) == "xiǎng chī KFC"
r = text_to_pinyin(txt, mode="full_with_tone", errors="split")
assert " ".join(r) == "xiǎng chī K F C"
r = text_to_pinyin(txt, mode="full_no_tone", errors="default")
assert " ".join(r) == "xiang chi KFC"
r = text_to_pinyin(txt, mode="full_no_tone", errors="split")
assert " ".join(r) == "xiang chi K F C"
r = text_to_pinyin(txt, mode="partial_with_tone")
assert " ".join(r) == "x iǎng ch ī KFC"
r = text_to_pinyin(txt, mode="partial_with_tone", errors="split")
assert " ".join(r) == "x iǎng ch ī K F C"
r = text_to_pinyin(txt, mode="partial_no_tone", errors="default")
assert " ".join(r) == "x iang ch i KFC"
r = text_to_pinyin(txt, mode="partial_no_tone", errors="split")
assert " ".join(r) == "x iang ch i K F C"