Mirror of https://github.com/k2-fsa/icefall.git

commit 19f88482be
parent 3059eb4511

    merge with master
@@ -556,18 +556,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -577,9 +573,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
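Both hunks above make the same change in save_results (the decoding script): the per-key component is dropped from the result file names, so the path now depends only on params.suffix; the decoding key still appears inside the header passed to write_error_stats. A toy illustration of the resulting paths, with made-up values standing in for params.res_dir and params.suffix:

from pathlib import Path

res_dir = Path("zipformer/exp/greedy_search")  # placeholder for params.res_dir
suffix = "epoch-30-avg-9"                      # placeholder for params.suffix
test_set_name = "test"

# Before: recogs-test-<key>-epoch-30-avg-9.txt  (one file per decoding key)
# After:  recogs-test-epoch-30-avg-9.txt        (key no longer in the name)
recog_path = res_dir / f"recogs-{test_set_name}-{suffix}.txt"
print(recog_path)  # zipformer/exp/greedy_search/recogs-test-epoch-30-avg-9.txt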
@@ -239,12 +239,22 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         chunk left-context frames will be chosen randomly from this list; else not relevant.""",
     )
 
+    parser.add_argument(
+        "--use-transducer",
+        type=str2bool,
+        default=True,
+        help="If True, use Transducer head.",
+    )
+
+    parser.add_argument(
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
+    )
+
 
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--world-size",
         type=int,
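The new --use-transducer / --use-ctc flags are parsed with str2bool, which icefall imports from icefall.utils. A minimal equivalent for reference (the real helper may differ in detail):

import argparse

def str2bool(v):
    # Accept the usual textual booleans so "--use-ctc 1/true/yes" all work.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")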
@@ -302,16 +312,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
     parser.add_argument(
         "--base-lr", type=float, default=0.045, help="The base learning rate."
     )
@@ -379,6 +379,13 @@ def get_parser():
         with this parameter before adding to the final loss.""",
     )
 
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
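The diff does not show where --ctc-loss-scale is consumed; the usual pattern in zipformer recipes is to scale the CTC loss before adding it to the total, gated by the two head flags added above. A hedged sketch (the function and variable names here are illustrative, not taken from the recipe's model code):

import torch

def combine_losses(
    transducer_loss: torch.Tensor,
    ctc_loss: torch.Tensor,
    use_transducer: bool = True,
    use_ctc: bool = False,
    ctc_loss_scale: float = 0.2,
) -> torch.Tensor:
    # Each enabled head contributes its (scaled) loss; with both heads on,
    # the model trains jointly on Transducer + scaled CTC.
    loss = torch.zeros(())
    if use_transducer:
        loss = loss + transducer_loss
    if use_ctc:
        loss = loss + ctc_loss_scale * ctc_loss
    return loss

# e.g. combine_losses(torch.tensor(3.2), torch.tensor(5.0), use_ctc=True)
# -> 3.2 + 0.2 * 5.0 = 4.2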
@@ -444,6 +451,24 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    add_training_arguments(parser)
+
     add_model_arguments(parser)
 
     return parser
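With the refactor, get_parser() keeps only recipe-local options such as --lang-dir and delegates the rest to add_training_arguments / add_model_arguments. A quick smoke test, assuming the training script is importable as train (a hypothetical import path):

from train import get_parser  # hypothetical import path

parser = get_parser()
args = parser.parse_args(
    ["--lang-dir", "data/lang_char", "--use-ctc", "1", "--ctc-loss-scale", "0.3"]
)
assert args.use_ctc is True
assert args.ctc_loss_scale == 0.3
assert args.lang_dir == "data/lang_char"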
@@ -1154,26 +1179,6 @@ def run(rank, world_size, args):
             # )
             return False
 
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
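The deleted block enforced the pruned-RNN-T requirement T >= S, where T counts feature frames after the Zipformer frontend's subsampling and S counts tokens in the transcript. For reference, the arithmetic of the removed expression:

def frames_after_subsampling(num_frames: int) -> int:
    # Mirrors the removed line: T = ((c.num_frames - 7) // 2 + 1) // 2
    return ((num_frames - 7) // 2 + 1) // 2

# 100 feature frames -> ((100 - 7) // 2 + 1) // 2 = 23 frames after
# subsampling, so under the removed check any transcript longer than
# 23 tokens would have excluded the cut from training.
assert frames_after_subsampling(100) == 23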
@@ -28,5 +28,6 @@ multi_quantization
 onnx
 onnxmltools
 onnxruntime
+pypinyin
 kaldifst
 kaldi-decoder
@@ -3,6 +3,7 @@ kaldilm
 kaldialign
 num2words
 kaldi-decoder
+pypinyin
 sentencepiece>=0.1.96
 pypinyin==0.50.0
 tensorboard
@@ -28,6 +28,7 @@ from icefall.utils import (
     encode_supervisions,
     get_texts,
     make_pad_mask,
+    text_to_pinyin,
 )
 
 
@@ -163,3 +164,31 @@ def test_add_eos():
         [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
     )
     assert str(ragged_eos) == str(expected)
+
+
+def test_text_to_pinyin():
+    txt = "想吃KFC"
+
+    r = text_to_pinyin(txt, mode="full_with_tone")
+    assert " ".join(r) == "xiǎng chī KFC"
+
+    r = text_to_pinyin(txt, mode="full_with_tone", errors="split")
+    assert " ".join(r) == "xiǎng chī K F C"
+
+    r = text_to_pinyin(txt, mode="full_no_tone", errors="default")
+    assert " ".join(r) == "xiang chi KFC"
+
+    r = text_to_pinyin(txt, mode="full_no_tone", errors="split")
+    assert " ".join(r) == "xiang chi K F C"
+
+    r = text_to_pinyin(txt, mode="partial_with_tone")
+    assert " ".join(r) == "x iǎng ch ī KFC"
+
+    r = text_to_pinyin(txt, mode="partial_with_tone", errors="split")
+    assert " ".join(r) == "x iǎng ch ī K F C"
+
+    r = text_to_pinyin(txt, mode="partial_no_tone", errors="default")
+    assert " ".join(r) == "x iang ch i KFC"
+
+    r = text_to_pinyin(txt, mode="partial_no_tone", errors="split")
+    assert " ".join(r) == "x iang ch i K F C"
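The new test covers text_to_pinyin, which this commit adds to the icefall.utils imports and backs with the pypinyin dependency added to the requirements files above. The helper's implementation is not part of this diff; a minimal sketch of just the full_with_tone mode on top of pypinyin, showing one way the errors="split" behaviour in the assertions could be obtained (toy_text_to_pinyin is illustrative, not the real helper):

from pypinyin import lazy_pinyin, Style

def toy_text_to_pinyin(txt: str, errors: str = "default"):
    # pypinyin keeps non-Chinese segments (e.g. "KFC") intact by default;
    # a callable errors-handler may return a list to split such a segment
    # into single characters: "KFC" -> ["K", "F", "C"].
    handler = (lambda chars: list(chars)) if errors == "split" else "default"
    return lazy_pinyin(txt, style=Style.TONE, errors=handler)

print(" ".join(toy_text_to_pinyin("想吃KFC")))                  # xiǎng chī KFC
print(" ".join(toy_text_to_pinyin("想吃KFC", errors="split")))  # xiǎng chī K F C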