diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py index 6aaf81565..4559ebb6d 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py @@ -45,33 +45,41 @@ import math import warnings from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Callable +from typing import Callable, Dict, List, Optional, Tuple -import torch.multiprocessing as mp import k2 -from lhotse import load_manifest_lazy import sentencepiece as spm import torch +import torch.multiprocessing as mp import torch.nn as nn -from transformers import BertTokenizer, BertModel from asr_datamodule import LibriHeavyAsrDataModule from beam_search import ( greedy_search, - greedy_search_with_context, greedy_search_batch, greedy_search_batch_with_context, + greedy_search_with_context, modified_beam_search, ) from dataset import naive_triplet_text_sampling, random_shuffle_subset -from utils import get_facebook_biasing_list -from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha +from lhotse import load_manifest_lazy +from text_normalization import ( + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) from train_bert_encoder_with_style import ( + _encode_texts_as_bytes_with_tokenizer, add_model_arguments, get_params, get_tokenizer, get_transducer_model, - _encode_texts_as_bytes_with_tokenizer, ) +from transformers import BertModel, BertTokenizer +from utils import get_facebook_biasing_list from icefall.checkpoint import ( average_checkpoints, @@ -89,11 +97,13 @@ from icefall.utils import ( ) LOG_EPS = math.log(1e-10) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - + parser.add_argument( "--world-size", type=int, @@ -144,7 +154,7 @@ def get_parser(): default="pruned_transducer_stateless7/exp", help="The experiment dir", ) - + parser.add_argument( "--log-dir", type=str, @@ -260,21 +270,20 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - + parser.add_argument( "--input-manifest", type=str, required=True, - help="The input manifest to be decoded" + help="The input manifest to be decoded", ) - + parser.add_argument( "--output-manifest", type=str, required=True, - help="Where to store the output manifest (directory)" + help="Where to store the output manifest (directory)", ) - parser.add_argument( "--use-pre-text", @@ -282,19 +291,19 @@ def get_parser(): default=True, help="Use pre-text is available during decoding", ) - + parser.add_argument( "--use-style-prompt", type=str2bool, default=True, - help="Use style prompt when evaluation" + help="Use style prompt when evaluation", ) - + parser.add_argument( "--use-context-embedding", type=str2bool, default=False, - help="Use context fuser when evaluation" + help="Use context fuser when evaluation", ) parser.add_argument( @@ -310,43 +319,44 @@ def get_parser(): default=True, help="Reports CER. By default, only reports WER", ) - + parser.add_argument( "--style-text-transform", type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], default="mixed-punc", - help="The style of style prompt, i.e style_text" + help="The style of style prompt, i.e style_text", ) - + parser.add_argument( "--pre-text-transform", type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], default="mixed-punc", - help="The style of content prompt, i.e pre_text" + help="The style of content prompt, i.e pre_text", ) - + parser.add_argument( "--use-ls-test-set", type=str2bool, default=False, - help="Use librispeech test set for evaluation." + help="Use librispeech test set for evaluation.", ) - + parser.add_argument( "--use-ls-context-list", type=str2bool, default=False, - help="If use a fixed context list for LibriSpeech decoding" + help="If use a fixed context list for LibriSpeech decoding", ) - + add_model_arguments(parser) return parser + def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in + """Apply transform to a list of text. By default, the text are in ground truth format, i.e mixed-punc. Args: @@ -366,7 +376,7 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]: return [lower_all_char(s) for s in text] else: raise NotImplementedError(f"Unseen transform: {transform}") - + def decode_one_batch( params: AttributeDict, @@ -421,37 +431,43 @@ def decode_one_batch( cuts = batch["supervisions"]["cut"] cut_ids = [c.supervisions[0].id for c in cuts] batch_size = feature.size(0) - + # get pre_text if "pre_text" in batch["supervisions"] and params.use_pre_text: - pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text + pre_texts = batch["supervisions"][ + "text" + ] # use the ground truth ref text as pre_text pre_texts = [train_text_normalization(t) for t in pre_texts] else: pre_texts = ["" for _ in range(batch_size)] - + if params.use_ls_context_list: pre_texts = [biasing_dict[id] for id in cut_ids] - + # get style_text if params.use_style_prompt: fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." - style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)]) + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) style_texts = [train_text_normalization(t) for t in style_texts] else: - style_texts = ["" for _ in range(batch_size)] # use empty string + style_texts = ["" for _ in range(batch_size)] # use empty string # Get the text embedding input if params.use_pre_text or params.use_style_prompt: # apply style transform to the pre_text and style_text pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) - #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) + # pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) if params.use_style_prompt: - style_texts = _apply_style_transform(style_texts, params.style_text_transform) - + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + with warnings.catch_warnings(): warnings.simplefilter("ignore") - + # Use tokenizer to prepare input for text encoder encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( pre_texts=pre_texts, @@ -459,11 +475,11 @@ def decode_one_batch( tokenizer=tokenizer, device=device, ) - + memory, memory_key_padding_mask = model.encode_text( encoded_inputs=encoded_inputs, style_lens=style_lens, - ) # (T,B,C) + ) # (T,B,C) else: memory = None memory_key_padding_mask = None @@ -487,10 +503,7 @@ def decode_one_batch( hyps = [] - if ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: if memory is None or not params.use_context_embedding: hyp_tokens = greedy_search_batch( model=model, @@ -498,9 +511,11 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) else: - memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C) - context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C) - context = model.joiner.context_proj(context) # (N,C) + memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) + context = model.context_fuser( + memory, padding_mask=memory_key_padding_mask + ) # (N,C) + context = model.joiner.context_proj(context) # (N,C) hyp_tokens = greedy_search_batch_with_context( model=model, encoder_out=encoder_out, @@ -533,19 +548,13 @@ def decode_one_batch( max_sym_per_frame=params.max_sym_per_frame, ) else: - cur_context = context[i:i+1, :] + cur_context = context[i : i + 1, :] hyp = greedy_search_with_context( model=model, encoder_out=encoder_out_i, context=cur_context, max_sym_per_frame=params.max_sym_per_frame, ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -608,13 +617,15 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format - + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + # the style of ref_text should match style_text - texts = _apply_style_transform(texts, params.style_text_transform) + texts = _apply_style_transform(texts, params.style_text_transform) if params.use_style_prompt: - texts = _apply_style_transform(texts, params.style_text_transform) - + texts = _apply_style_transform(texts, params.style_text_transform) + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -645,9 +656,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -677,7 +686,9 @@ def save_results( if params.compute_CER: # Write CER statistics - recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) store_transcripts(filename=recog_path, texts=results, char_level=True) errs_filename = ( params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" @@ -695,9 +706,7 @@ def save_results( logging.info("Wrote detailed CER stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -712,9 +721,7 @@ def save_results( if params.compute_CER: test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_cers: @@ -740,65 +747,69 @@ def add_decoding_result_to_manifest( for items in value: id, ref, hyp = items new_ans[id] = " ".join(hyp) + def _add_decoding(c): key = c.supervisions[0].id c.supervisions[0].texts.append(new_ans[key]) return c + in_manifest = in_manifest.map(_add_decoding) logging.info(f"Saving manifest to {out_manifest}") in_manifest.to_file(out_manifest) - + def main(): parser = get_parser() LibriHeavyAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) - + cuts = load_manifest_lazy(args.input_manifest) - + world_size = args.world_size assert world_size >= 1 if world_size > 1: splitted_cuts = cuts.split(num_splits=world_size) - mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True) + mp.spawn( + run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True + ) else: run(rank=0, world_size=1, args=args, cuts=cuts) - + @torch.no_grad() def run(rank, world_size, args, cuts): params = get_params() params.update(vars(args)) params.res_dir = params.exp_dir / params.decoding_method - + if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - + if params.use_pre_text: params.suffix += f"-pre-text-{params.pre_text_transform}" - + if params.use_style_prompt: params.suffix += f"-style-prompt-{params.style_text_transform}" - + params.suffix += f"-{rank}" world_size = params.world_size - + params.output_manifest = Path(params.output_manifest) if world_size > 1: cuts = cuts[rank] out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz" else: - out_name = params.output_manifest / f"with_decoding.jsonl.gz" - + out_name = params.output_manifest / "with_decoding.jsonl.gz" + device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) - - setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}") + + setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}") logging.info("Decoding started") logging.info(f"Device: {device}") @@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -909,14 +920,16 @@ def run(rank, world_size, args, cuts): args.return_cuts = True libriheavy = LibriHeavyAsrDataModule(args) - dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling) - + dl = libriheavy.valid_dataloaders( + cuts, text_sampling_func=naive_triplet_text_sampling + ) + test_sets = ["test"] test_dl = [dl] for test_set, test_dl in zip(test_sets, test_dl): biasing_dict = None - + results_dict = decode_dataset( dl=test_dl, params=params, @@ -933,7 +946,7 @@ def run(rank, world_size, args, cuts): # test_set_name=test_set, # results_dict=results_dict, # ) - + add_decoding_result_to_manifest( in_manifest=cuts, out_manifest=out_name, @@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts): logging.info("Done!") + # torch.set_num_threads(1) # torch.set_num_interop_threads(1)