fix style

commit e32bda6a7b
parent 90dac69bc5
Author: marcoyang
Date: 2023-10-10 16:55:31 +08:00


@@ -45,33 +45,41 @@ import math
 import warnings
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Callable
-import torch.multiprocessing as mp
+from typing import Callable, Dict, List, Optional, Tuple
+
 import k2
-from lhotse import load_manifest_lazy
 import sentencepiece as spm
 import torch
+import torch.multiprocessing as mp
 import torch.nn as nn
-from transformers import BertTokenizer, BertModel
 from asr_datamodule import LibriHeavyAsrDataModule
 from beam_search import (
     greedy_search,
-    greedy_search_with_context,
     greedy_search_batch,
     greedy_search_batch_with_context,
+    greedy_search_with_context,
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
-from utils import get_facebook_biasing_list
-from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
+from lhotse import load_manifest_lazy
+from text_normalization import (
+    lower_all_char,
+    lower_only_alpha,
+    ref_text_normalization,
+    remove_non_alphabetic,
+    train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
+)
 from train_bert_encoder_with_style import (
+    _encode_texts_as_bytes_with_tokenizer,
     add_model_arguments,
     get_params,
     get_tokenizer,
     get_transducer_model,
-    _encode_texts_as_bytes_with_tokenizer,
 )
+from transformers import BertModel, BertTokenizer
+from utils import get_facebook_biasing_list

 from icefall.checkpoint import (
     average_checkpoints,
@@ -89,11 +97,13 @@ from icefall.utils import (
 )

 LOG_EPS = math.log(1e-10)

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -144,7 +154,7 @@ def get_parser():
         default="pruned_transducer_stateless7/exp",
         help="The experiment dir",
     )

     parser.add_argument(
         "--log-dir",
         type=str,
@@ -260,21 +270,20 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
-
     parser.add_argument(
         "--input-manifest",
         type=str,
         required=True,
-        help="The input manifest to be decoded"
+        help="The input manifest to be decoded",
     )

     parser.add_argument(
         "--output-manifest",
         type=str,
         required=True,
-        help="Where to store the output manifest (directory)"
+        help="Where to store the output manifest (directory)",
     )

     parser.add_argument(
         "--use-pre-text",
@@ -282,19 +291,19 @@ def get_parser():
         default=True,
         help="Use pre-text is available during decoding",
     )

     parser.add_argument(
         "--use-style-prompt",
         type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Use style prompt when evaluation",
     )

     parser.add_argument(
         "--use-context-embedding",
         type=str2bool,
         default=False,
-        help="Use context fuser when evaluation"
+        help="Use context fuser when evaluation",
     )

     parser.add_argument(
@@ -310,43 +319,44 @@ def get_parser():
         default=True,
         help="Reports CER. By default, only reports WER",
     )

     parser.add_argument(
         "--style-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of style prompt, i.e style_text",
     )

     parser.add_argument(
         "--pre-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of content prompt, i.e pre_text",
     )

     parser.add_argument(
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
     )

     parser.add_argument(
         "--use-ls-context-list",
         type=str2bool,
         default=False,
-        help="If use a fixed context list for LibriSpeech decoding"
+        help="If use a fixed context list for LibriSpeech decoding",
     )

     add_model_arguments(parser)

     return parser

+
 def _apply_style_transform(text: List[str], transform: str) -> List[str]:
     """Apply transform to a list of text. By default, the text are in
     ground truth format, i.e mixed-punc.

     Args:
@@ -366,7 +376,7 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
         return [lower_all_char(s) for s in text]
     else:
         raise NotImplementedError(f"Unseen transform: {transform}")

 def decode_one_batch(
     params: AttributeDict,
@@ -421,37 +431,43 @@ def decode_one_batch(
     cuts = batch["supervisions"]["cut"]
     cut_ids = [c.supervisions[0].id for c in cuts]
     batch_size = feature.size(0)

     # get pre_text
     if "pre_text" in batch["supervisions"] and params.use_pre_text:
-        pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text
+        pre_texts = batch["supervisions"][
+            "text"
+        ]  # use the ground truth ref text as pre_text
         pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]

     if params.use_ls_context_list:
         pre_texts = [biasing_dict[id] for id in cut_ids]

     # get style_text
     if params.use_style_prompt:
         fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
-        style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
+        style_texts = batch["supervisions"].get(
+            "style_text", [fixed_sentence for _ in range(batch_size)]
+        )
         style_texts = [train_text_normalization(t) for t in style_texts]
     else:
         style_texts = ["" for _ in range(batch_size)]  # use empty string

     # Get the text embedding input
     if params.use_pre_text or params.use_style_prompt:
         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-        #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
+        # pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)

         if params.use_style_prompt:
-            style_texts = _apply_style_transform(style_texts, params.style_text_transform)
+            style_texts = _apply_style_transform(
+                style_texts, params.style_text_transform
+            )

     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         # Use tokenizer to prepare input for text encoder
         encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
             pre_texts=pre_texts,
@@ -459,11 +475,11 @@ def decode_one_batch(
             tokenizer=tokenizer,
             device=device,
         )

         memory, memory_key_padding_mask = model.encode_text(
             encoded_inputs=encoded_inputs,
             style_lens=style_lens,
-        ) # (T,B,C)
+        )  # (T,B,C)
     else:
         memory = None
         memory_key_padding_mask = None
@@ -487,10 +503,7 @@ def decode_one_batch(
     hyps = []

-    if (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         if memory is None or not params.use_context_embedding:
             hyp_tokens = greedy_search_batch(
                 model=model,
@@ -498,9 +511,11 @@ def decode_one_batch(
                 encoder_out_lens=encoder_out_lens,
             )
         else:
-            memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
-            context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
-            context = model.joiner.context_proj(context) # (N,C)
+            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
+            context = model.context_fuser(
+                memory, padding_mask=memory_key_padding_mask
+            )  # (N,C)
+            context = model.joiner.context_proj(context)  # (N,C)
             hyp_tokens = greedy_search_batch_with_context(
                 model=model,
                 encoder_out=encoder_out,
@@ -533,19 +548,13 @@ def decode_one_batch(
                         max_sym_per_frame=params.max_sym_per_frame,
                     )
                 else:
-                    cur_context = context[i:i+1, :]
+                    cur_context = context[i : i + 1, :]
                     hyp = greedy_search_with_context(
                         model=model,
                         encoder_out=encoder_out_i,
                         context=cur_context,
                         max_sym_per_frame=params.max_sym_per_frame,
                     )
-            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -608,13 +617,15 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-        texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
+        texts = batch["supervisions"][
+            "text"
+        ]  # By default, this should be in mixed-punc format

         # the style of ref_text should match style_text
         texts = _apply_style_transform(texts, params.style_text_transform)
         if params.use_style_prompt:
             texts = _apply_style_transform(texts, params.style_text_transform)

         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
@@ -645,9 +656,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

     return results
@@ -677,7 +686,9 @@ def save_results(
     if params.compute_CER:
         # Write CER statistics
-        recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+        )
         store_transcripts(filename=recog_path, texts=results, char_level=True)
         errs_filename = (
             params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -695,9 +706,7 @@ def save_results(
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -712,9 +721,7 @@ def save_results(
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
         with open(errs_info, "w") as f:
             print("settings\tCER", file=f)
             for key, val in test_set_cers:
@@ -740,65 +747,69 @@ def add_decoding_result_to_manifest(
         for items in value:
             id, ref, hyp = items
             new_ans[id] = " ".join(hyp)

     def _add_decoding(c):
         key = c.supervisions[0].id
         c.supervisions[0].texts.append(new_ans[key])
         return c

     in_manifest = in_manifest.map(_add_decoding)
     logging.info(f"Saving manifest to {out_manifest}")
     in_manifest.to_file(out_manifest)

+
 def main():
     parser = get_parser()
     LibriHeavyAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

     cuts = load_manifest_lazy(args.input_manifest)

     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
         splitted_cuts = cuts.split(num_splits=world_size)
-        mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True)
+        mp.spawn(
+            run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
+        )
     else:
         run(rank=0, world_size=1, args=args, cuts=cuts)

+
 @torch.no_grad()
 def run(rank, world_size, args, cuts):
     params = get_params()
     params.update(vars(args))

     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if params.use_pre_text:
         params.suffix += f"-pre-text-{params.pre_text_transform}"

     if params.use_style_prompt:
         params.suffix += f"-style-prompt-{params.style_text_transform}"

     params.suffix += f"-{rank}"

     world_size = params.world_size
     params.output_manifest = Path(params.output_manifest)
     if world_size > 1:
         cuts = cuts[rank]
         out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
     else:
-        out_name = params.output_manifest / f"with_decoding.jsonl.gz"
+        out_name = params.output_manifest / "with_decoding.jsonl.gz"

     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)

     setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
     logging.info("Decoding started")
     logging.info(f"Device: {device}")
@@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts):
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts):
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -909,14 +920,16 @@ def run(rank, world_size, args, cuts):
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)

-    dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling)
+    dl = libriheavy.valid_dataloaders(
+        cuts, text_sampling_func=naive_triplet_text_sampling
+    )

     test_sets = ["test"]
     test_dl = [dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         biasing_dict = None
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -933,7 +946,7 @@ def run(rank, world_size, args, cuts):
         #     test_set_name=test_set,
         #     results_dict=results_dict,
         # )

         add_decoding_result_to_manifest(
             in_manifest=cuts,
             out_manifest=out_name,
@@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts):
     logging.info("Done!")

+
 # torch.set_num_threads(1)
 # torch.set_num_interop_threads(1)
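
The edits above are consistent with black plus isort formatting (sorted and grouped imports, trailing commas, 88-column line wrapping, "i : i + 1" slice spacing, two spaces before inline comments). A minimal sketch of applying the same kind of cleanup programmatically, assuming `black` and `isort` are installed; the exact versions pinned by icefall's pre-commit config may format slightly differently:

    import black
    import isort

    # A messy snippet in the style of the pre-commit version of this file.
    # The names used here are placeholders, not the real module's API.
    messy = (
        "from typing import Dict, List, Optional, Tuple, Callable\n"
        "import torch.multiprocessing as mp\n"
        "import k2\n"
        "\n"
        "def decode(texts):\n"
        "    return _apply_style_transform(texts, style_text_transform, use_style_prompt=True, use_pre_text=True) # comment\n"
    )

    # isort groups and alphabetizes the imports (stdlib before third-party);
    # black then enforces line wrapping, trailing commas, and comment spacing.
    tidied = black.format_str(isort.code(messy), mode=black.Mode())
    print(tidied)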