mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
fix style
This commit is contained in:
parent
90dac69bc5
commit
e32bda6a7b
@ -45,33 +45,41 @@ import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Callable
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import k2
|
||||
from lhotse import load_manifest_lazy
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from transformers import BertTokenizer, BertModel
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_with_context,
|
||||
greedy_search_batch,
|
||||
greedy_search_batch_with_context,
|
||||
greedy_search_with_context,
|
||||
modified_beam_search,
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from utils import get_facebook_biasing_list
|
||||
from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
||||
from lhotse import load_manifest_lazy
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
)
|
||||
from transformers import BertModel, BertTokenizer
|
||||
from utils import get_facebook_biasing_list
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -89,11 +97,13 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
@ -144,7 +154,7 @@ def get_parser():
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
@ -260,21 +270,20 @@ def get_parser():
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--input-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input manifest to be decoded"
|
||||
help="The input manifest to be decoded",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--output-manifest",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the output manifest (directory)"
|
||||
help="Where to store the output manifest (directory)",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
@ -282,19 +291,19 @@ def get_parser():
|
||||
default=True,
|
||||
help="Use pre-text is available during decoding",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
help="Use style prompt when evaluation",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-context-embedding",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use context fuser when evaluation"
|
||||
help="Use context fuser when evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -310,43 +319,44 @@ def get_parser():
|
||||
default=True,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-test-set",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use librispeech test set for evaluation."
|
||||
help="Use librispeech test set for evaluation.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-context-list",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If use a fixed context list for LibriSpeech decoding"
|
||||
help="If use a fixed context list for LibriSpeech decoding",
|
||||
)
|
||||
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
@ -366,7 +376,7 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
@ -421,37 +431,43 @@ def decode_one_batch(
|
||||
cuts = batch["supervisions"]["cut"]
|
||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
|
||||
# get pre_text
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text
|
||||
pre_texts = batch["supervisions"][
|
||||
"text"
|
||||
] # use the ground truth ref text as pre_text
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
|
||||
if params.use_ls_context_list:
|
||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||
|
||||
|
||||
# get style_text
|
||||
if params.use_style_prompt:
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
||||
style_texts = batch["supervisions"].get(
|
||||
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||
)
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
# Get the text embedding input
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
# pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
style_texts = _apply_style_transform(
|
||||
style_texts, params.style_text_transform
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
@ -459,11 +475,11 @@ def decode_one_batch(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
@ -487,10 +503,7 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
@ -498,9 +511,11 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
|
||||
context = model.joiner.context_proj(context) # (N,C)
|
||||
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(
|
||||
memory, padding_mask=memory_key_padding_mask
|
||||
) # (N,C)
|
||||
context = model.joiner.context_proj(context) # (N,C)
|
||||
hyp_tokens = greedy_search_batch_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -533,19 +548,13 @@ def decode_one_batch(
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
cur_context = context[i:i+1, :]
|
||||
cur_context = context[i : i + 1, :]
|
||||
hyp = greedy_search_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
context=cur_context,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -608,13 +617,15 @@ def decode_dataset(
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
||||
|
||||
texts = batch["supervisions"][
|
||||
"text"
|
||||
] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
if params.use_style_prompt:
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
@ -645,9 +656,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -677,7 +686,9 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
@ -695,9 +706,7 @@ def save_results(
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
@ -712,9 +721,7 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
@ -740,65 +747,69 @@ def add_decoding_result_to_manifest(
|
||||
for items in value:
|
||||
id, ref, hyp = items
|
||||
new_ans[id] = " ".join(hyp)
|
||||
|
||||
def _add_decoding(c):
|
||||
key = c.supervisions[0].id
|
||||
c.supervisions[0].texts.append(new_ans[key])
|
||||
return c
|
||||
|
||||
in_manifest = in_manifest.map(_add_decoding)
|
||||
logging.info(f"Saving manifest to {out_manifest}")
|
||||
in_manifest.to_file(out_manifest)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
|
||||
cuts = load_manifest_lazy(args.input_manifest)
|
||||
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
splitted_cuts = cuts.split(num_splits=world_size)
|
||||
mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True)
|
||||
mp.spawn(
|
||||
run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
|
||||
)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args, cuts=cuts)
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run(rank, world_size, args, cuts):
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
|
||||
if params.use_pre_text:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
||||
|
||||
|
||||
if params.use_style_prompt:
|
||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||
|
||||
|
||||
params.suffix += f"-{rank}"
|
||||
|
||||
world_size = params.world_size
|
||||
|
||||
|
||||
params.output_manifest = Path(params.output_manifest)
|
||||
if world_size > 1:
|
||||
cuts = cuts[rank]
|
||||
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
|
||||
else:
|
||||
out_name = params.output_manifest / f"with_decoding.jsonl.gz"
|
||||
|
||||
out_name = params.output_manifest / "with_decoding.jsonl.gz"
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
|
||||
|
||||
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts):
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts):
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -909,14 +920,16 @@ def run(rank, world_size, args, cuts):
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
|
||||
dl = libriheavy.valid_dataloaders(
|
||||
cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dl = [dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
biasing_dict = None
|
||||
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
@ -933,7 +946,7 @@ def run(rank, world_size, args, cuts):
|
||||
# test_set_name=test_set,
|
||||
# results_dict=results_dict,
|
||||
# )
|
||||
|
||||
|
||||
add_decoding_result_to_manifest(
|
||||
in_manifest=cuts,
|
||||
out_manifest=out_name,
|
||||
@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts):
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
# torch.set_num_threads(1)
|
||||
# torch.set_num_interop_threads(1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user