mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
fix style
This commit is contained in:
parent
90dac69bc5
commit
e32bda6a7b
@ -45,33 +45,41 @@ import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Callable
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import k2
|
||||
from lhotse import load_manifest_lazy
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from transformers import BertTokenizer, BertModel
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_with_context,
|
||||
greedy_search_batch,
|
||||
greedy_search_batch_with_context,
|
||||
greedy_search_with_context,
|
||||
modified_beam_search,
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from utils import get_facebook_biasing_list
|
||||
from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
||||
from lhotse import load_manifest_lazy
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
)
|
||||
from transformers import BertModel, BertTokenizer
|
||||
from utils import get_facebook_biasing_list
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -89,6 +97,8 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -265,17 +275,16 @@ def get_parser():
|
||||
"--input-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input manifest to be decoded"
|
||||
help="The input manifest to be decoded",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the output manifest (directory)"
|
||||
help="Where to store the output manifest (directory)",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
@ -287,14 +296,14 @@ def get_parser():
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
help="Use style prompt when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-context-embedding",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use context fuser when evaluation"
|
||||
help="Use context fuser when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -314,37 +323,38 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-test-set",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use librispeech test set for evaluation."
|
||||
help="Use librispeech test set for evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-context-list",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If use a fixed context list for LibriSpeech decoding"
|
||||
help="If use a fixed context list for LibriSpeech decoding",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
@ -424,7 +434,9 @@ def decode_one_batch(
|
||||
|
||||
# get pre_text
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text
|
||||
pre_texts = batch["supervisions"][
|
||||
"text"
|
||||
] # use the ground truth ref text as pre_text
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
@ -435,7 +447,9 @@ def decode_one_batch(
|
||||
# get style_text
|
||||
if params.use_style_prompt:
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
||||
style_texts = batch["supervisions"].get(
|
||||
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||
)
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
@ -445,9 +459,11 @@ def decode_one_batch(
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
# pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
style_texts = _apply_style_transform(
|
||||
style_texts, params.style_text_transform
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
@ -487,10 +503,7 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
@ -498,8 +511,10 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
|
||||
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(
|
||||
memory, padding_mask=memory_key_padding_mask
|
||||
) # (N,C)
|
||||
context = model.joiner.context_proj(context) # (N,C)
|
||||
hyp_tokens = greedy_search_batch_with_context(
|
||||
model=model,
|
||||
@ -533,19 +548,13 @@ def decode_one_batch(
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
cur_context = context[i:i+1, :]
|
||||
cur_context = context[i : i + 1, :]
|
||||
hyp = greedy_search_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
context=cur_context,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -608,7 +617,9 @@ def decode_dataset(
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
||||
texts = batch["supervisions"][
|
||||
"text"
|
||||
] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
@ -645,9 +656,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -677,7 +686,9 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
@ -695,9 +706,7 @@ def save_results(
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
@ -712,9 +721,7 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
@ -740,10 +747,12 @@ def add_decoding_result_to_manifest(
|
||||
for items in value:
|
||||
id, ref, hyp = items
|
||||
new_ans[id] = " ".join(hyp)
|
||||
|
||||
def _add_decoding(c):
|
||||
key = c.supervisions[0].id
|
||||
c.supervisions[0].texts.append(new_ans[key])
|
||||
return c
|
||||
|
||||
in_manifest = in_manifest.map(_add_decoding)
|
||||
logging.info(f"Saving manifest to {out_manifest}")
|
||||
in_manifest.to_file(out_manifest)
|
||||
@ -761,7 +770,9 @@ def main():
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
splitted_cuts = cuts.split(num_splits=world_size)
|
||||
mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True)
|
||||
mp.spawn(
|
||||
run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
|
||||
)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args, cuts=cuts)
|
||||
|
||||
@ -792,7 +803,7 @@ def run(rank, world_size, args, cuts):
|
||||
cuts = cuts[rank]
|
||||
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
|
||||
else:
|
||||
out_name = params.output_manifest / f"with_decoding.jsonl.gz"
|
||||
out_name = params.output_manifest / "with_decoding.jsonl.gz"
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts):
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts):
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -909,7 +920,9 @@ def run(rank, world_size, args, cuts):
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
dl = libriheavy.valid_dataloaders(
|
||||
cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dl = [dl]
|
||||
@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts):
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
# torch.set_num_threads(1)
|
||||
# torch.set_num_interop_threads(1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user