From 657980072087eb8df0f078c6cb898b8e5e8421b5 Mon Sep 17 00:00:00 2001 From: marcoyang1998 Date: Tue, 19 Sep 2023 18:38:56 +0800 Subject: [PATCH] update --- .../zipformer_prompt_asr/asr_datamodule.py | 4 - .../ASR/zipformer_prompt_asr/decode_bert.py | 368 +++++++++--------- .../train_bert_encoder.py | 353 +++++++++-------- 3 files changed, 368 insertions(+), 357 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 4b4c8a785..7a2a61a30 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -72,16 +72,12 @@ class LibriHeavyAsrDataModule: self.args = args if args.use_context_list: - from dataset2 import PromptASRDataset - assert args.rare_word_file is not None with open(args.rare_word_file, "r") as f: self.rare_word_list = ( f.read().lower().split() ) # Use lower-cased for easier style transform else: - from dataset import PromptASRDataset - self.rare_word_list = None @classmethod diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py index 0aa23d49a..18b3e9a14 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py @@ -20,22 +20,55 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless7/decode.py \ +./zipformer_prompt_asr/decode_bert.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ --decoding-method greedy_search (2) modified beam search -./pruned_transducer_stateless7/decode.py \ +./zipformer_prompt_asr/decode_bert.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --max-duration 600 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ --decoding-method modified_beam_search \ --beam-size 4 +(3) Decode LibriSpeech + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --use-ls-test-set True \ + --beam-size 4 + +(4) Decode LibriSpeech + biasing list + +biasing_list=100 # could also be 0 + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \\ + --use-ls-test-set True \ + --use-ls-context-list True \ + --biasing-level utterance \ + --ls-distractors $biasing_list \ + --post-normalization True \ + --text-encoder-type BERT \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc + + """ @@ -45,40 +78,34 @@ import math import warnings from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Callable +from typing import Callable, Dict, List, Optional, Tuple import k2 import sentencepiece as spm import torch import torch.nn as nn -from transformers import BertTokenizer, BertModel from asr_datamodule import LibriHeavyAsrDataModule -from beam_search import ( - greedy_search, - greedy_search_with_context, - greedy_search_batch, - greedy_search_batch_with_context, - modified_beam_search, -) +from beam_search import greedy_search, greedy_search_batch, modified_beam_search from dataset import naive_triplet_text_sampling, random_shuffle_subset -from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats from ls_text_normalization import word_normalization from text_normalization import ( - ref_text_normalization, - remove_non_alphabetic, - upper_only_alpha, - upper_all_char, lower_all_char, lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, train_text_normalization, + upper_all_char, + upper_only_alpha, ) -from train_bert_encoder_with_style import ( +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, add_model_arguments, get_params, get_tokenizer, get_transducer_model, - _encode_texts_as_bytes_with_tokenizer, ) +from transformers import BertModel, BertTokenizer +from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats from icefall.checkpoint import ( average_checkpoints, @@ -87,15 +114,11 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, -) +from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -262,26 +285,19 @@ def get_parser(): default=True, help="Use pre-text is available during decoding", ) - + parser.add_argument( "--use-style-prompt", type=str2bool, default=True, - help="Use style prompt when evaluation" + help="Use style prompt when evaluation", ) - + parser.add_argument( "--max-prompt-lens", type=int, default=1000, ) - - parser.add_argument( - "--use-context-embedding", - type=str2bool, - default=False, - help="Use context fuser when evaluation" - ) parser.add_argument( "--post-normalization", @@ -289,70 +305,65 @@ def get_parser(): default=True, help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", ) - - parser.add_argument( - "--long-audio-recog", - type=str2bool, - default=False, - ) parser.add_argument( "--compute-CER", type=str2bool, - default=True, + default=False, help="Reports CER. By default, only reports WER", ) - + parser.add_argument( "--style-text-transform", type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], default="mixed-punc", - help="The style of style prompt, i.e style_text" + help="The style of style prompt, i.e style_text", ) - + parser.add_argument( "--pre-text-transform", type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], default="mixed-punc", - help="The style of content prompt, i.e pre_text" + help="The style of content prompt, i.e pre_text", ) - + parser.add_argument( "--use-ls-test-set", type=str2bool, default=False, - help="Use librispeech test set for evaluation." + help="Use librispeech test set for evaluation.", ) - + parser.add_argument( "--use-ls-context-list", type=str2bool, default=False, - help="If use a fixed context list for LibriSpeech decoding" + help="If use a fixed context list for LibriSpeech decoding", ) - + parser.add_argument( "--biasing-level", type=str, default="utterance", choices=["utterance", "Book", "Chapter"], ) - + parser.add_argument( "--ls-distractors", type=int, default=0, - help="The number of distractors into context list for LibriSpeech decoding" + help="The number of distractors into context list for LibriSpeech decoding", ) - + add_model_arguments(parser) return parser + def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in + """Apply transform to a list of text. By default, the text are in ground truth format, i.e mixed-punc. Args: @@ -372,13 +383,13 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]: return [lower_all_char(s) for s in text] else: raise NotImplementedError(f"Unseen transform: {transform}") - + def decode_one_batch( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - tokenizer, + tokenizer: spm.SentencePieceProcessor, batch: dict, biasing_dict: dict = None, word_table: Optional[k2.SymbolTable] = None, @@ -401,10 +412,15 @@ def decode_one_batch( The neural model. sp: The BPE model. + tokenizer: + Tokenizer for the text encoder batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) word_table: The word symbol table. decoding_graph: @@ -427,48 +443,53 @@ def decode_one_batch( cuts = batch["supervisions"]["cut"] cut_ids = [c.supervisions[0].id for c in cuts] batch_size = feature.size(0) - - # get pre_text + if "pre_text" in batch["supervisions"] and params.use_pre_text: pre_texts = batch["supervisions"]["pre_text"] pre_texts = [train_text_normalization(t) for t in pre_texts] else: pre_texts = ["" for _ in range(batch_size)] - - if params.use_ls_context_list: + + if params.use_ls_context_list and params.use_ls_test_set: if params.biasing_level == "utterance": pre_texts = [biasing_dict[id] for id in cut_ids] elif params.biasing_level == "Chapter": - chapter_ids = [c.split('-')[1] for c in cut_ids] + chapter_ids = [c.split("-")[1] for c in cut_ids] pre_texts = [biasing_dict[id] for id in chapter_ids] elif params.biasing_level == "Book": - chapter_ids = [c.split('-')[1] for c in cut_ids] + chapter_ids = [c.split("-")[1] for c in cut_ids] pre_texts = [biasing_dict[id] for id in chapter_ids] + else: + raise ValueError(f"Unseen biasing level: {params.biasing_level}") if params.pre_text_transform == "mixed-punc": pre_texts = [t.lower() for t in pre_texts] - + # get style_text if params.use_style_prompt: fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." - style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)]) + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) style_texts = [train_text_normalization(t) for t in style_texts] else: - style_texts = ["" for _ in range(batch_size)] # use empty string + style_texts = ["" for _ in range(batch_size)] # use empty string - # Get the text embedding input + # Get the text embedding if params.use_pre_text or params.use_style_prompt: # apply style transform to the pre_text and style_text pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) if not params.use_ls_context_list: - pre_texts = [t[:params.max_prompt_lens] for t in pre_texts] - #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) + pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts] + if params.use_style_prompt: - style_texts = _apply_style_transform(style_texts, params.style_text_transform) - + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + with warnings.catch_warnings(): warnings.simplefilter("ignore") - + # Use tokenizer to prepare input for text encoder encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( pre_texts=pre_texts, @@ -477,12 +498,14 @@ def decode_one_batch( device=device, no_limit=True, ) - logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}") - + logging.info( + f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}" + ) + memory, memory_key_padding_mask = model.encode_text( encoded_inputs=encoded_inputs, style_lens=style_lens, - ) # (T,B,C) + ) # (T,B,C) else: memory = None memory_key_padding_mask = None @@ -506,26 +529,12 @@ def decode_one_batch( hyps = [] - if ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): - if memory is None or not params.use_context_embedding: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C) - context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C) - context = model.joiner.context_proj(context) # (N,C) - hyp_tokens = greedy_search_batch_with_context( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context=context, - ) + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": @@ -545,25 +554,10 @@ def decode_one_batch( encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - if memory is None or not params.use_context_embedding: - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - else: - cur_context = context[i:i+1, :] - hyp = greedy_search_with_context( - model=model, - encoder_out=encoder_out_i, - context=cur_context, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( + hyp = greedy_search( model=model, encoder_out=encoder_out_i, - beam=params.beam_size, + max_sym_per_frame=params.max_sym_per_frame, ) else: raise ValueError( @@ -582,7 +576,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - tokenizer, + tokenizer: spm.SentencePieceProcessor, biasing_dict: Dict = None, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, @@ -598,6 +592,11 @@ def decode_dataset( The neural model. sp: The BPE model. + tokenizer: + Tokenizer for the text encoder + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) word_table: The word symbol table. decoding_graph: @@ -627,19 +626,25 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format - + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + # the style of ref_text should match style_text - texts = _apply_style_transform(texts, params.style_text_transform) + texts = _apply_style_transform(texts, params.style_text_transform) if params.use_style_prompt: - texts = _apply_style_transform(texts, params.style_text_transform) - + texts = _apply_style_transform(texts, params.style_text_transform) + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] if not params.use_ls_test_set: try: - book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]] - except: - book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]] + book_names = [ + cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] + ] + except AttributeError: + book_names = [ + cut.id.split("/")[0] for cut in batch["supervisions"]["cut"] + ] else: book_names = ["" for _ in cut_ids] @@ -657,7 +662,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts): + for cut_id, book_name, hyp_words, ref_text in zip( + cut_ids, book_names, hyps, texts + ): ref_text = ref_text_normalization( ref_text ) # remove full-width symbols & some book marks @@ -672,9 +679,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -705,7 +710,9 @@ def save_results( if params.compute_CER: # Write CER statistics - recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) store_transcripts(filename=recog_path, texts=results, char_level=True) errs_filename = ( params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" @@ -723,9 +730,7 @@ def save_results( logging.info("Wrote detailed CER stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -740,9 +745,7 @@ def save_results( if params.compute_CER: test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_cers: @@ -770,11 +773,8 @@ def main(): "greedy_search", "modified_beam_search", ) - - if params.long_audio_recog: - params.res_dir = params.exp_dir / (params.decoding_method + "long_audio") - else: - params.res_dir = params.exp_dir / params.decoding_method + + params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" @@ -792,22 +792,19 @@ def main(): params.suffix += f"-left-context-{params.left_context_frames}" if "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" if params.use_pre_text: - params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}" - + params.suffix += ( + f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}" + ) + if params.use_style_prompt: params.suffix += f"-style-prompt-{params.style_text_transform}" - - if params.use_context_embedding: - params.suffix += f"-use-context-fuser" - + if params.use_ls_context_list: params.suffix += f"-use-{params.biasing_level}-level-ls-context-list" if params.biasing_level == "utterance" and params.ls_distractors: @@ -841,9 +838,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -870,9 +867,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -935,18 +932,15 @@ def main(): test_other_cuts = libriheavy.test_other_cuts() ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() - long_audio_cuts = libriheavy.long_audio_cuts() - - npr1_dev_cuts = libriheavy.npr1_dev_cuts() - npr1_test_cuts = libriheavy.npr1_test_cuts() - test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling) - test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling) + test_clean_dl = libriheavy.valid_dataloaders( + test_clean_cuts, text_sampling_func=naive_triplet_text_sampling + ) + test_other_dl = libriheavy.valid_dataloaders( + test_other_cuts, text_sampling_func=naive_triplet_text_sampling + ) ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) - long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling) - npr1_dev_dl = libriheavy.valid_dataloaders(npr1_dev_cuts, text_sampling_func=naive_triplet_text_sampling) - npr1_test_dl = libriheavy.valid_dataloaders(npr1_test_cuts, text_sampling_func=naive_triplet_text_sampling) if params.use_ls_test_set: test_sets = ["ls-test-clean", "ls-test-other"] @@ -954,19 +948,21 @@ def main(): else: test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] - - if params.long_audio_recog: - test_sets = ["long-audio"] - test_dl = [long_audio_dl] for test_set, test_dl in zip(test_sets, test_dl): - if test_set == "ls-test-clean": - biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors) - elif test_set == "ls-test-other": - biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors) - else: - biasing_dict = None - + biasing_dict = None + if params.use_ls_context_list: + if test_set == "ls-test-clean": + biasing_dict = get_facebook_biasing_list( + test_set="test-clean", + num_distractors=params.ls_distractors, + ) + elif test_set == "ls-test-other": + biasing_dict = get_facebook_biasing_list( + test_set="test-other", + num_distractors=params.ls_distractors, + ) + results_dict = decode_dataset( dl=test_dl, params=params, @@ -983,35 +979,37 @@ def main(): test_set_name=test_set, results_dict=results_dict, ) - + if params.post_normalization: if "-post-normalization" not in params.suffix: params.suffix += "-post-normalization" - + new_res = {} for k in results_dict: new_ans = [] for item in results_dict[k]: id, ref, hyp = item if params.use_ls_test_set: - hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens + hyp = ( + " ".join(hyp).replace("-", " ").split() + ) # handle the hypens hyp = upper_only_alpha(" ".join(hyp)).split() hyp = [word_normalization(w.upper()) for w in hyp] hyp = " ".join(hyp).split() hyp = [w for w in hyp if w != ""] - ref = upper_only_alpha(" ".join(ref)).split() + ref = upper_only_alpha(" ".join(ref)).split() else: hyp = upper_only_alpha(" ".join(hyp)).split() - ref = upper_only_alpha(" ".join(ref)).split() - new_ans.append((id,ref,hyp)) + ref = upper_only_alpha(" ".join(ref)).split() + new_ans.append((id, ref, hyp)) new_res[k] = new_ans - + save_results( params=params, test_set_name=test_set, results_dict=new_res, ) - + if params.suffix.endswith("-post-normalization"): params.suffix = params.suffix.replace("-post-normalization", "") diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index de1b6ab85..bde640fb6 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -22,24 +22,43 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - # For mix precision training: -./pruned_transducer_stateless7/train.py \ +(1) Non-streaming model, without context list + +./zipformer_prompt_asr/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --memory-dim 768 \ + --text-encoder-type BERT \ + --use-style-prompt True \ + --use-context-list False + +(2) Non-streaming model, with context list + +./zipformer_prompt_asr/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --memory-dim 768 \ + --text-encoder-type BERT \ + --use-style-prompt True \ + --use-context-list True \ + --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt + """ @@ -61,30 +80,32 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriHeavyAsrDataModule -from dataset2 import ( - triplet_text_sampling, - triplet_text_sampling_with_context_list, +from dataset import ( naive_triplet_text_sampling, random_shuffle_subset, - joint_triplet_text_sampling, - triplet_style_text_sampling, + triplet_text_sampling, + triplet_text_sampling_with_context_list, ) -from dataset import multi_ref_text_triplet_text_sampling - from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model_with_BERT_with_style import PromptedTransducer +from model_with_BERT import PromptedTransducer from optim import Eden, ScaledAdam -from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR +from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR from subsampling import Conv2dSubsampling +from text_normalization import ( + lower_all_char, + lower_only_alpha, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from text_normalization import train_text_normalization, upper_only_alpha, lower_only_alpha, upper_all_char, lower_all_char from zipformer import Zipformer2 from icefall import diagnostics @@ -105,20 +126,20 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] style_transforms = [ - lambda x: x, # return it self + lambda x: x, # return it self upper_only_alpha, lower_only_alpha, - lower_all_char, + lower_all_char, ] + def random_sampling(texts: List[str]) -> str: return random.choice(texts) + def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str: # Randomly choose from the ground truth (mixed-cased trans) and the recog_text i = random.randint(0, 1) @@ -130,6 +151,7 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str: } return out + def get_first(texts: List[str], pre_texts: List[str]) -> str: out = { "text": texts[0], @@ -139,6 +161,7 @@ def get_first(texts: List[str], pre_texts: List[str]) -> str: } return out + def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str: # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha out = { @@ -149,6 +172,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str: } return out + def get_adjusted_batch_count(params: AttributeDict) -> float: # returns the number of batches we would have used so far if we had used the reference # duration. This is for purposes of set_batch_count(). @@ -205,19 +229,19 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="192,256,384,512,384,256", help="Embedding dimension in encoder stacks: a single int or comma-separated list.", ) - + parser.add_argument( "--memory-dropout-rate", type=float, default=0.05, - help="By which probability, dropout the memory when doing cross-attention." + help="By which probability, dropout the memory when doing cross-attention.", ) - + parser.add_argument( "--memory-layer", type=int, default=0, - help="Start doing cross-attention from which layer. Zero-indexed" + help="Start doing cross-attention from which layer. Zero-indexed", ) parser.add_argument( @@ -226,7 +250,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="32", help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) - + parser.add_argument( "--value-head-dim", type=str, @@ -280,13 +304,12 @@ def add_model_arguments(parser: argparse.ArgumentParser): to this dimension before adding. """, ) - + parser.add_argument( "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -312,29 +335,29 @@ def add_model_arguments(parser: argparse.ArgumentParser): "be converted to a number of chunks. If splitting into chunks, " "chunk left-context frames will be chosen randomly from this list; else not relevant.", ) - + parser.add_argument( "--text-encoder-type", type=str, default="BERT", - choices=["BERT","DistilBERT"], + choices=["BERT", "DistilBERT"], help="Type of the text encoder", ) - + parser.add_argument( "--text-encoder-adapter", type=str2bool, default=False, - help="An adapter for pre-trained BERT" + help="An adapter for pre-trained BERT", ) - + parser.add_argument( "--context-injection", type=str2bool, default=False, help="Inject context embedding into the joiner", ) - + parser.add_argument( "--context-dropout-rate", type=float, @@ -459,8 +482,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -537,14 +559,14 @@ def get_parser(): default=False, help="Whether to use half precision training.", ) - + parser.add_argument( "--use-style-prompt", type=str2bool, default=True, help="Whether to use style prompt.", ) - + # arguments for using prompt parser.add_argument( "--pre-text-shuffle-prob", @@ -552,14 +574,14 @@ def get_parser(): default=0.05, help="The proportion of pre_text to be shuffled with in a batch", ) - + parser.add_argument( "--style-text-shuffle-prob", type=float, default=0.2, help="The proportion of style_text to be shuffled with in a batch", ) - + parser.add_argument( "--prompt-mask-prob", type=float, @@ -571,14 +593,14 @@ def get_parser(): type=str2bool, default=True, ) - + parser.add_argument( "--forced-upper-pre-text", type=str2bool, default=False, help="Forced format of pre-text", ) - + add_model_arguments(parser) return parser @@ -674,25 +696,25 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: class TextEmbedding(nn.Module): def __init__( self, - num_embeddings: int=256, - embedding_dim: int=256, - kernel_size: int=3, + num_embeddings: int = 256, + embedding_dim: int = 256, + kernel_size: int = 3, layer1_channels: int = 256, layer2_channels: int = 256, - bias: bool=True, - dropout: float = 0.1 + bias: bool = True, + dropout: float = 0.1, ): super().__init__() self.embed = nn.Embedding( num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes - embedding_dim=embedding_dim, # + embedding_dim=embedding_dim, # ) - - assert embedding_dim == layer1_channels # for depth wise convolution + + assert embedding_dim == layer1_channels # for depth wise convolution self.conv = nn.Sequential( nn.Conv1d( embedding_dim, - layer1_channels, # depthwise convolution + layer1_channels, # depthwise convolution kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, @@ -705,7 +727,7 @@ class TextEmbedding(nn.Module): nn.Conv1d( layer1_channels, layer2_channels, - kernel_size=1, # pointwise convolution + kernel_size=1, # pointwise convolution stride=1, padding=0, bias=True, @@ -713,10 +735,10 @@ class TextEmbedding(nn.Module): Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), nn.ReLU(), ) - + self.out_norm = BiasNorm(layer2_channels) self.dropout = Dropout3(dropout, shared_dim=1) - + def forward(self, text: torch.Tensor) -> torch.Tensor: """Forward function of the text embedding @@ -725,51 +747,57 @@ class TextEmbedding(nn.Module): Returns: The embeddings of text (T,N,C) """ - text = self.embed(text) # (T,N,C) - - #src = text - text = text.permute(1,2,0) # (T,N,C) -> (N,C,T) + text = self.embed(text) # (T,N,C) + + # src = text + text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T) text = self.conv(text) - text = text.permute(2,0,1) # (N,C,T) -> (T,N,C) - #src = src + text - + text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C) + # src = src + text + text = self.out_norm(text) text = self.dropout(text) - + return text - + def get_text_encoder(params: AttributeDict) -> nn.Module: # Return a text encoder if params.text_encoder_type == "BERT": from transformers import BertModel + # This is a BERT-base-cased logging.info("Loading pre-trained BERT-base-cased as text encoder") model = BertModel.from_pretrained("bert-base-cased") elif params.text_encoder_type == "DistilBERT": from transformers import DistilBertModel + # This is a DistilBERT-base-cased logging.info("Loading pre-trained DistilBERT-base-cased as text encoder") model = DistilBertModel.from_pretrained("distilbert-base-cased") else: raise ValueError() - + return model + def get_tokenizer(params: AttributeDict): - + if params.text_encoder_type == "BERT": from transformers import BertTokenizer + # This is a BERT-base-cased - tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") elif params.text_encoder_type == "DistilBERT": from transformers import DistilBertTokenizer - tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased') + + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased") else: raise ValueError() - + return tokenizer + def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Zipformer2( output_downsampling_factor=2, @@ -789,7 +817,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), left_context_frames=_to_int_tuple(params.left_context_frames), - memory_dim=768, # This is fixed as the BERT base model is 768-D + memory_dim=768, # This is fixed as the BERT base model is 768-D memory_layer=params.memory_layer, memory_dropout_rate=params.memory_dropout_rate, ) @@ -812,7 +840,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, - context_dim=4 * 768 if params.context_injection else -1, # the output dim of text encoder + context_dim=4 * 768 + if params.context_injection + else -1, # the output dim of text encoder context_injection=params.context_injection, ) return joiner @@ -821,23 +851,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_transducer_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - text_encoder = get_text_encoder(params) # This should be a cased BERT base model + text_encoder = get_text_encoder(params) # This should be a cased BERT base model num_param = sum([p.numel() for p in text_encoder.parameters()]) logging.info(f"Num params in text encoder: {num_param}") decoder = get_decoder_model(params) joiner = get_joiner_model(params) - - if params.context_injection: - from context_fuser import ContextFuser, SelfAttContextFuser - context_fuser = SelfAttContextFuser( - embed_dim=768, - nhead=4, - context_dropout_rate=params.context_dropout_rate, - ) - logging.info(f"Using context injection!") - logging.info(context_fuser) - else: - context_fuser = None model = PromptedTransducer( encoder_embed=encoder_embed, @@ -851,12 +869,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, text_encoder_type=params.text_encoder_type, text_encoder_adapter=params.text_encoder_adapter, - context_fuser=context_fuser, + context_fuser=None, ) - - if params.text_encoder_adapter: - logging.info(f"Using adapter for BERT encoder") - logging.info(f"{model.text_encoder_adapter}") + return model @@ -978,13 +993,14 @@ def save_checkpoint( best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) + def _encode_texts_as_bytes_with_tokenizer( - pre_texts: List[str], + pre_texts: List[str], style_texts: List[str], tokenizer, device: torch.device, - max_len: int=500, - no_limit: bool=False + max_len: int = 500, + no_limit: bool = False, ) -> Tuple[Dict, Tensor]: """ Encode texts as bytes and then integer tensors. @@ -992,36 +1008,39 @@ def _encode_texts_as_bytes_with_tokenizer( """ batch_size = len(pre_texts) max_len = min(max_len, 500) - + if no_limit: allowed_lens = [5000 - len(s) for s in style_texts] else: allowed_lens = [1000 - len(s) for s in style_texts] - truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)] - combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)] - + truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)] + combined_text = [ + style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size) + ] + encoded_style_texts = tokenizer( style_texts, - return_tensors='pt', + return_tensors="pt", padding=True, truncation=True, return_length=True, max_length=max_len, ) style_lens = encoded_style_texts["length"].to(device) - + # Use tokenizer to prepare input for text encoder encoded_inputs = tokenizer( combined_text, - return_tensors='pt', + return_tensors="pt", padding=True, truncation=True, return_length=True, max_length=max_len, ).to(device) - + return encoded_inputs, style_lens - + + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -1048,11 +1067,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -1067,20 +1082,24 @@ def compute_loss( texts = batch["supervisions"]["text"] pre_texts = batch["supervisions"]["pre_text"] - style_texts = batch["supervisions"]["style_text"] # the style texts are in gt format + style_texts = batch["supervisions"][ + "style_text" + ] # the style texts are in gt format transform_ids = batch["supervisions"]["transform_ids"] - + # This is to replace full-width symbols with half-width symbols texts = [train_text_normalization(t) for t in texts] pre_texts = [train_text_normalization(t) for t in pre_texts] style_texts = [train_text_normalization(t) for t in style_texts] - - y = sp.encode(texts, out_type=int) # sp.encode treats consecutive space as a single space + + y = sp.encode( + texts, out_type=int + ) # sp.encode treats consecutive space as a single space y = k2.RaggedTensor(y).to(device) - + if params.forced_upper_pre_text: pre_texts = [upper_only_alpha(p) for p in pre_texts] - + # only shuffle the pre_text and style texts if during training, and use style prompt if is_training: # randomly shuffle&mask the pre_text @@ -1089,38 +1108,40 @@ def compute_loss( p=params.pre_text_shuffle_prob, p_mask=params.prompt_mask_prob, ) - + if params.use_style_prompt: - if random.random() < 0.5: + if random.random() < 0.5: # randomly shuffle the style_text # now the style_texts are all in gt format style_texts = random_shuffle_subset( style_texts, p=params.style_text_shuffle_prob, - p_mask=params.prompt_mask_prob - ) - + p_mask=params.prompt_mask_prob, + ) + assert len(transform_ids) == len(style_texts) - + for i in range(len(style_texts)): - t = transform_ids[i] # get the transform id + t = transform_ids[i] # get the transform id style_texts[i] = style_transforms[t](style_texts[i]) if not params.use_style_prompt: - style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt - + style_texts = [ + "" for _ in style_texts + ] # use empty string for style texts if don't use style prompt + if random.random() < 0.05: logging.info(f"Pre texts: {pre_texts[0]}") logging.info(f"Ref texts: {texts[0]}") logging.info(f"Style texts: {style_texts[0]}") - + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( pre_texts=pre_texts, style_texts=style_texts, tokenizer=tokenizer, device=device, ) - + if random.random() < 0.02: logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ") @@ -1157,9 +1178,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -1352,9 +1371,7 @@ def train_one_epoch( # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: @@ -1376,11 +1393,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -1391,9 +1404,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -1401,10 +1412,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1452,11 +1460,15 @@ def run(rank, world_size, args): setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") - + if not params.use_style_prompt: - if params.pre_text_shuffle_prob == 0.0: - logging.info(f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}") - logging.info("If style prompt is not used, you should be careful when shuffling the pre_text within the same batch") + if params.pre_text_shuffle_prob == 0.0: + logging.info( + f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}" + ) + logging.info( + "If style prompt is not used, you should be careful when shuffling the pre_text within the same batch" + ) logging.info("Hard set this probability to 0.0!") params.pre_text_shuffle_prob = 0.0 @@ -1504,10 +1516,12 @@ def run(rank, world_size, args): if params.freeze_text_encoder: freeze_modules = ["text_encoder"] - logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer") + logging.info( + "Freeze the parameters of text encoder and don't include them in the optimizer" + ) else: freeze_modules = [] - + optimizer = ScaledAdam( get_parameter_groups_with_lrs( model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules @@ -1533,7 +1547,7 @@ def run(rank, world_size, args): if params.print_diagnostics: args.max_duration = 100 opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1543,7 +1557,7 @@ def run(rank, world_size, args): libriheavy = LibriHeavyAsrDataModule(args) train_cuts = libriheavy.train_cuts() - + def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1586,10 +1600,14 @@ def run(rank, world_size, args): sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None - - text_sampling_func = triplet_text_sampling + + if params.use_context_list: + text_sampling_func = triplet_text_sampling_with_context_list + else: + text_sampling_func = triplet_text_sampling + logging.info(f"Text sampling: {text_sampling_func}") - + train_dl = libriheavy.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict, @@ -1599,18 +1617,17 @@ def run(rank, world_size, args): # For fair comparison, use fixed sampling in valid dataloaders valid_cuts = libriheavy.dev_cuts() valid_dl = libriheavy.valid_dataloaders( - valid_cuts, - text_sampling_func=naive_triplet_text_sampling + valid_cuts, text_sampling_func=naive_triplet_text_sampling ) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: