replace file

Kinan Martin 2025-04-14 08:27:50 +09:00
parent 1e9bb87305
commit a4be3cb3db
2 changed files with 67 additions and 1149 deletions


@@ -103,6 +103,9 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
# import sentencepiece as spm
from tokenizer import Tokenizer
import torch
import torch.nn as nn
from asr_datamodule import MLSEnglishHFAsrDataModule
@@ -120,7 +123,7 @@ from beam_search import (
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from tokenizer import Tokenizer
# from gigaspeech_scoring import asr_text_post_processing
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
@@ -194,18 +197,25 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
# parser.add_argument(
# "--bpe-model",
# type=str,
# default="data/lang_bpe_500/bpe.model",
# help="Path to the BPE model",
# )
# parser.add_argument(
# "--lang-dir",
# type=Path,
# default="data/lang_bpe_500",
# help="The lang dir containing word table and LG graph",
# )
parser.add_argument(
"--lang-dir",
type=Path,
type=str,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
help="Path to the lang dir with the BPE model (`bpe.model`)",
)
parser.add_argument(
@@ -370,23 +380,24 @@ def get_parser():
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
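Note: the removed --blank-penalty option implements exactly the operation its help text describes: a positive value subtracted from the blank logit before a symbol is chosen. A minimal sketch of that behaviour (hypothetical shapes; blank id 0 as stated above):

import torch

batch_size, vocab_size = 4, 500
logits = torch.randn(batch_size, vocab_size)  # per-frame logits, shape [batch_size, vocab]
blank_penalty = 1.0                           # positive value; 0.0 disables the penalty
logits[:, 0] -= blank_penalty                 # suppress the blank symbol (blank id is 0)
best = logits.argmax(dim=-1)                  # greedy pick now favours non-blank symbols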
add_model_arguments(parser)
return parser
def asr_text_post_processing(inp):
return inp
def post_processing(
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((key, new_ref, new_hyp))
return new_results
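Note: with asr_text_post_processing reduced to an identity function, post_processing simply round-trips each reference and hypothesis through a join and a split, leaving the word lists unchanged. A quick sanity check (hypothetical cut id and words):

results = [("cut-001", ["hello", "world"], ["hello", "word"])]
assert post_processing(results) == results  # identity post-processing is a no-op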
def decode_one_batch(
params: AttributeDict,
@@ -470,10 +481,9 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@@ -485,7 +495,6 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@@ -500,10 +509,9 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
@@ -516,19 +524,17 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@@ -536,10 +542,9 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
@@ -549,7 +554,7 @@ def decode_one_batch(
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
@@ -562,7 +567,7 @@ def decode_one_batch(
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
@@ -608,9 +613,8 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.text2word(sp.decode(hyp)))
hyps.append(sp.decode(hyp).split())
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
@@ -714,7 +718,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = sp.text2word(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
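Note: throughout this diff, sp.text2word(hyp) becomes hyp.split(): the BPE tokenizer's decode already returns detokenized, whitespace-separated text, so a plain split recovers the word list. A minimal sketch with sentencepiece directly (assuming the model path from the commented-out --bpe-model default exists):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed BPE model location
ids = sp.encode("hello world")          # list of BPE token ids
text = sp.decode(ids)                   # "hello world", detokenized
words = text.split()                    # ["hello", "world"], as the decode paths above now do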
@@ -738,6 +742,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@@ -776,7 +781,7 @@ def save_results(
def main():
parser = get_parser()
MLSEnglishHFAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -847,8 +852,6 @@ def main():
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@@ -861,9 +864,13 @@ def main():
logging.info(f"Device: {device}")
sp = Tokenizer.load(Path(args.lang_dir), "bpe")
# sp = spm.SentencePieceProcessor()
# sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/prepare_lang_char.py
sp = Tokenizer.load(Path(args.lang_dir), "bpe") # force bpe model
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
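Note: Tokenizer.load(Path(args.lang_dir), "bpe") replaces the direct sentencepiece load shown in the commented-out lines, while exposing the same piece-level queries used here. The equivalent with plain sentencepiece (assuming bpe.model sits inside --lang-dir):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # bpe.model from the lang dir
blank_id = sp.piece_to_id("<blk>")      # params.blank_id
unk_id = sp.piece_to_id("<unk>")        # params.unk_id
vocab_size = sp.get_piece_size()        # params.vocab_size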
@@ -1022,9 +1029,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append((sp.encode(line.strip()), 0.0))
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
@@ -1038,16 +1045,21 @@ def main():
mls_english_corpus = MLSEnglishHFAsrDataModule(args)
mls_english_corpus.load_hf_dataset("/root/datasets/parler-tts--mls_eng")
# dev_cuts = mls_english_corpus.dev_cuts()
test_cuts = mls_english_corpus.test_cuts()
for subdir in ["valid"]:
# dev_dl = mls_english_corpus.test_dataloaders(dev_cuts)
test_dl = mls_english_corpus.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
# test_sets = ["dev", "test"]
# test_dls = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl = mls_english_corpus.test_dataloaders(
test_cuts,
),
# dl=mls_english_corpus.test_dataloaders(
# getattr(mls_english_corpus, f"{subdir}_cuts")()
# ),
dl=test_dl,
params=params,
model=model,
sp=sp,
@@ -1058,22 +1070,12 @@ def main():
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
tot_err = save_results(
save_results(
params=params,
test_set_name=subdir,
test_set_name=test_set,
results_dict=results_dict,
)
# with (
# params.res_dir
# / (
# f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
# f"_{params.avg}_{params.epoch}.cer"
# )
# ).open("w") as fout:
# if len(tot_err) == 1:
# fout.write(f"{tot_err[0][1]}")
# else:
# fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
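Note: the commented-out block above used to dump the aggregate error rate returned by save_results into a .cer file; its multi-result branch also passed a generator to fout.write instead of joining first. A cleaned-up sketch of what it did (hypothetical result values and file name):

from pathlib import Path

res_dir = Path("zipformer/exp/greedy_search")  # hypothetical params.res_dir
res_dir.mkdir(parents=True, exist_ok=True)
tot_err = [("greedy_search", 5.23)]            # (key, error-rate) pairs, as formerly returned
with (res_dir / "test-32_4_9_30.cer").open("w") as fout:  # name built from chunk len, beam, avg, epoch
    if len(tot_err) == 1:
        fout.write(f"{tot_err[0][1]}")
    else:
        fout.write("\n".join(f"{k}\t{v}" for k, v in tot_err))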
logging.info("Done!")

File diff suppressed because it is too large.