fix style

marcoyang 2023-10-10 16:55:31 +08:00
parent 90dac69bc5
commit e32bda6a7b


@@ -45,33 +45,41 @@ import math
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Callable
from typing import Callable, Dict, List, Optional, Tuple
import torch.multiprocessing as mp
import k2
from lhotse import load_manifest_lazy
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import (
greedy_search,
greedy_search_with_context,
greedy_search_batch,
greedy_search_batch_with_context,
greedy_search_with_context,
modified_beam_search,
)
from dataset import naive_triplet_text_sampling, random_shuffle_subset
from utils import get_facebook_biasing_list
from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
from lhotse import load_manifest_lazy
from text_normalization import (
lower_all_char,
lower_only_alpha,
ref_text_normalization,
remove_non_alphabetic,
train_text_normalization,
upper_all_char,
upper_only_alpha,
)
from train_bert_encoder_with_style import (
_encode_texts_as_bytes_with_tokenizer,
add_model_arguments,
get_params,
get_tokenizer,
get_transducer_model,
_encode_texts_as_bytes_with_tokenizer,
)
from transformers import BertModel, BertTokenizer
from utils import get_facebook_biasing_list
from icefall.checkpoint import (
average_checkpoints,
@@ -89,11 +97,13 @@ from icefall.utils import (
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
@@ -144,7 +154,7 @@ def get_parser():
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--log-dir",
type=str,
@@ -260,21 +270,20 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--input-manifest",
type=str,
required=True,
help="The input manifest to be decoded"
help="The input manifest to be decoded",
)
parser.add_argument(
"--output-manifest",
type=str,
required=True,
help="Where to store the output manifest (directory)"
help="Where to store the output manifest (directory)",
)
parser.add_argument(
"--use-pre-text",
@@ -282,19 +291,19 @@ def get_parser():
default=True,
help="Use pre-text is available during decoding",
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Use style prompt when evaluation"
help="Use style prompt when evaluation",
)
parser.add_argument(
"--use-context-embedding",
type=str2bool,
default=False,
help="Use context fuser when evaluation"
help="Use context fuser when evaluation",
)
parser.add_argument(
@@ -310,43 +319,44 @@ def get_parser():
default=True,
help="Reports CER. By default, only reports WER",
)
parser.add_argument(
"--style-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of style prompt, i.e style_text"
help="The style of style prompt, i.e style_text",
)
parser.add_argument(
"--pre-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of content prompt, i.e pre_text"
help="The style of content prompt, i.e pre_text",
)
parser.add_argument(
"--use-ls-test-set",
type=str2bool,
default=False,
help="Use librispeech test set for evaluation."
help="Use librispeech test set for evaluation.",
)
parser.add_argument(
"--use-ls-context-list",
type=str2bool,
default=False,
help="If use a fixed context list for LibriSpeech decoding"
help="If use a fixed context list for LibriSpeech decoding",
)
add_model_arguments(parser)
return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
"""Apply transform to a list of text. By default, the text are in
"""Apply transform to a list of text. By default, the text are in
ground truth format, i.e mixed-punc.
Args:
@@ -366,7 +376,7 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
return [lower_all_char(s) for s in text]
else:
raise NotImplementedError(f"Unseen transform: {transform}")
def decode_one_batch(
params: AttributeDict,
@@ -421,37 +431,43 @@ def decode_one_batch(
cuts = batch["supervisions"]["cut"]
cut_ids = [c.supervisions[0].id for c in cuts]
batch_size = feature.size(0)
# get pre_text
if "pre_text" in batch["supervisions"] and params.use_pre_text:
pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text
pre_texts = batch["supervisions"][
"text"
] # use the ground truth ref text as pre_text
pre_texts = [train_text_normalization(t) for t in pre_texts]
else:
pre_texts = ["" for _ in range(batch_size)]
if params.use_ls_context_list:
pre_texts = [biasing_dict[id] for id in cut_ids]
# get style_text
if params.use_style_prompt:
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
style_texts = batch["supervisions"].get(
"style_text", [fixed_sentence for _ in range(batch_size)]
)
style_texts = [train_text_normalization(t) for t in style_texts]
else:
style_texts = ["" for _ in range(batch_size)] # use empty string
style_texts = ["" for _ in range(batch_size)] # use empty string
# Get the text embedding input
if params.use_pre_text or params.use_style_prompt:
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
# pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
style_texts = _apply_style_transform(
style_texts, params.style_text_transform
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Use tokenizer to prepare input for text encoder
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
pre_texts=pre_texts,
@@ -459,11 +475,11 @@ def decode_one_batch(
tokenizer=tokenizer,
device=device,
)
memory, memory_key_padding_mask = model.encode_text(
encoded_inputs=encoded_inputs,
style_lens=style_lens,
) # (T,B,C)
) # (T,B,C)
else:
memory = None
memory_key_padding_mask = None
@@ -487,10 +503,7 @@ def decode_one_batch(
hyps = []
if (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
if memory is None or not params.use_context_embedding:
hyp_tokens = greedy_search_batch(
model=model,
@@ -498,9 +511,11 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
)
else:
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
context = model.joiner.context_proj(context) # (N,C)
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
context = model.context_fuser(
memory, padding_mask=memory_key_padding_mask
) # (N,C)
context = model.joiner.context_proj(context) # (N,C)
hyp_tokens = greedy_search_batch_with_context(
model=model,
encoder_out=encoder_out,
@@ -533,19 +548,13 @@ def decode_one_batch(
max_sym_per_frame=params.max_sym_per_frame,
)
else:
cur_context = context[i:i+1, :]
cur_context = context[i : i + 1, :]
hyp = greedy_search_with_context(
model=model,
encoder_out=encoder_out_i,
context=cur_context,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@@ -608,13 +617,15 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
texts = batch["supervisions"][
"text"
] # By default, this should be in mixed-punc format
# the style of ref_text should match style_text
texts = _apply_style_transform(texts, params.style_text_transform)
texts = _apply_style_transform(texts, params.style_text_transform)
if params.use_style_prompt:
texts = _apply_style_transform(texts, params.style_text_transform)
texts = _apply_style_transform(texts, params.style_text_transform)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@@ -645,9 +656,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -677,7 +686,9 @@ def save_results(
if params.compute_CER:
# Write CER statistics
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
recog_path = (
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results, char_level=True)
errs_filename = (
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -695,9 +706,7 @@ def save_results(
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -712,9 +721,7 @@ def save_results(
if params.compute_CER:
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_cers:
@@ -740,65 +747,69 @@ def add_decoding_result_to_manifest(
for items in value:
id, ref, hyp = items
new_ans[id] = " ".join(hyp)
def _add_decoding(c):
key = c.supervisions[0].id
c.supervisions[0].texts.append(new_ans[key])
return c
in_manifest = in_manifest.map(_add_decoding)
logging.info(f"Saving manifest to {out_manifest}")
in_manifest.to_file(out_manifest)
def main():
parser = get_parser()
LibriHeavyAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
cuts = load_manifest_lazy(args.input_manifest)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
splitted_cuts = cuts.split(num_splits=world_size)
mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True)
mp.spawn(
run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
)
else:
run(rank=0, world_size=1, args=args, cuts=cuts)
@torch.no_grad()
def run(rank, world_size, args, cuts):
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_pre_text:
params.suffix += f"-pre-text-{params.pre_text_transform}"
if params.use_style_prompt:
params.suffix += f"-style-prompt-{params.style_text_transform}"
params.suffix += f"-{rank}"
world_size = params.world_size
params.output_manifest = Path(params.output_manifest)
if world_size > 1:
cuts = cuts[rank]
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
else:
out_name = params.output_manifest / f"with_decoding.jsonl.gz"
out_name = params.output_manifest / "with_decoding.jsonl.gz"
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
logging.info("Decoding started")
logging.info(f"Device: {device}")
@@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts):
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts):
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -909,14 +920,16 @@ def run(rank, world_size, args, cuts):
args.return_cuts = True
libriheavy = LibriHeavyAsrDataModule(args)
dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling)
dl = libriheavy.valid_dataloaders(
cuts, text_sampling_func=naive_triplet_text_sampling
)
test_sets = ["test"]
test_dl = [dl]
for test_set, test_dl in zip(test_sets, test_dl):
biasing_dict = None
results_dict = decode_dataset(
dl=test_dl,
params=params,
@@ -933,7 +946,7 @@ def run(rank, world_size, args, cuts):
# test_set_name=test_set,
# results_dict=results_dict,
# )
add_decoding_result_to_manifest(
in_manifest=cuts,
out_manifest=out_name,
@@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts):
logging.info("Done!")
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
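
For readers of this diff, here is a minimal sketch of the complete _apply_style_transform helper, of which only the tail (the lower-punc branch and the NotImplementedError fallback) is visible in the hunks above. It maps each --pre-text-transform / --style-text-transform choice to one of the normalization helpers imported from text_normalization. The mixed-punc, upper-no-punc, and lower-no-punc branches are assumptions inferred from the imported helper names, so treat this as an illustration rather than the exact function body from the commit.

from typing import List

from text_normalization import (
    lower_all_char,
    lower_only_alpha,
    upper_only_alpha,
)


def _apply_style_transform(text: List[str], transform: str) -> List[str]:
    """Apply a transform to a list of texts that start out in mixed-punc
    (ground-truth) format."""
    if transform == "mixed-punc":
        # Ground-truth style: keep casing and punctuation as-is.
        return text
    elif transform == "upper-no-punc":
        # Assumed branch: upper-case and strip punctuation.
        return [upper_only_alpha(s) for s in text]
    elif transform == "lower-no-punc":
        # Assumed branch: lower-case and strip punctuation.
        return [lower_only_alpha(s) for s in text]
    elif transform == "lower-punc":
        # This branch is visible in the diff above.
        return [lower_all_char(s) for s in text]
    else:
        raise NotImplementedError(f"Unseen transform: {transform}")

Under these assumptions, calling _apply_style_transform(texts, "lower-punc") lower-cases the texts while keeping punctuation, which is the same normalization decode_dataset applies to the reference texts before scoring when a style prompt is used.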