mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
merge with master
This commit is contained in:
parent
3059eb4511
commit
19f88482be
@ -556,18 +556,14 @@ def save_results(
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
@ -577,9 +573,7 @@ def save_results(
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
|
@ -239,12 +239,22 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
parser.add_argument(
|
||||
"--use-transducer",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If True, use Transducer head.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ctc",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use CTC head.",
|
||||
)
|
||||
|
||||
|
||||
def add_training_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
@ -302,16 +312,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-lr", type=float, default=0.045, help="The base learning rate."
|
||||
)
|
||||
@ -379,6 +379,13 @@ def get_parser():
|
||||
with this parameter before adding to the final loss.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ctc-loss-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Scale for CTC loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -444,6 +451,24 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
add_training_arguments(parser)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -1154,26 +1179,6 @@ def run(rank, world_size, args):
|
||||
# )
|
||||
return False
|
||||
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
# where T is the number of feature frames after subsampling
|
||||
# and S is the number of tokens in the utterance
|
||||
|
||||
# In ./zipformer.py, the conv module uses the following expression
|
||||
# for subsampling
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
|
||||
|
||||
if T < len(tokens):
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. "
|
||||
f"Number of frames (before subsampling): {c.num_frames}. "
|
||||
f"Number of frames (after subsampling): {T}. "
|
||||
f"Text: {c.supervisions[0].text}. "
|
||||
f"Tokens: {tokens}. "
|
||||
f"Number of tokens: {len(tokens)}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
@ -28,5 +28,6 @@ multi_quantization
|
||||
onnx
|
||||
onnxmltools
|
||||
onnxruntime
|
||||
pypinyin
|
||||
kaldifst
|
||||
kaldi-decoder
|
||||
|
@ -3,6 +3,7 @@ kaldilm
|
||||
kaldialign
|
||||
num2words
|
||||
kaldi-decoder
|
||||
pypinyin
|
||||
sentencepiece>=0.1.96
|
||||
pypinyin==0.50.0
|
||||
tensorboard
|
||||
|
@ -28,6 +28,7 @@ from icefall.utils import (
|
||||
encode_supervisions,
|
||||
get_texts,
|
||||
make_pad_mask,
|
||||
text_to_pinyin,
|
||||
)
|
||||
|
||||
|
||||
@ -163,3 +164,31 @@ def test_add_eos():
|
||||
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
|
||||
)
|
||||
assert str(ragged_eos) == str(expected)
|
||||
|
||||
|
||||
def test_text_to_pinyin():
|
||||
txt = "想吃KFC"
|
||||
|
||||
r = text_to_pinyin(txt, mode="full_with_tone")
|
||||
assert " ".join(r) == "xiǎng chī KFC"
|
||||
|
||||
r = text_to_pinyin(txt, mode="full_with_tone", errors="split")
|
||||
assert " ".join(r) == "xiǎng chī K F C"
|
||||
|
||||
r = text_to_pinyin(txt, mode="full_no_tone", errors="default")
|
||||
assert " ".join(r) == "xiang chi KFC"
|
||||
|
||||
r = text_to_pinyin(txt, mode="full_no_tone", errors="split")
|
||||
assert " ".join(r) == "xiang chi K F C"
|
||||
|
||||
r = text_to_pinyin(txt, mode="partial_with_tone")
|
||||
assert " ".join(r) == "x iǎng ch ī KFC"
|
||||
|
||||
r = text_to_pinyin(txt, mode="partial_with_tone", errors="split")
|
||||
assert " ".join(r) == "x iǎng ch ī K F C"
|
||||
|
||||
r = text_to_pinyin(txt, mode="partial_no_tone", errors="default")
|
||||
assert " ".join(r) == "x iang ch i KFC"
|
||||
|
||||
r = text_to_pinyin(txt, mode="partial_no_tone", errors="split")
|
||||
assert " ".join(r) == "x iang ch i K F C"
|
||||
|
Loading…
x
Reference in New Issue
Block a user