From 3dbbc294296d2d3940b7e5745ec98ab8ec630fad Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Tue, 4 Jun 2024 22:56:57 -0700 Subject: [PATCH] add decode file --- egs/speech_llm/ASR_LLM/debug.sh | 27 ++ egs/speech_llm/ASR_LLM/run.sh | 24 +- .../ASR_LLM/whisper_llm_zh/decode.py | 263 +++++++++--------- .../ASR_LLM/whisper_llm_zh/model.py | 53 +++- .../ASR_LLM/whisper_llm_zh/multi_dataset.py | 13 +- .../ASR_LLM/whisper_llm_zh/train.py | 3 + 6 files changed, 230 insertions(+), 153 deletions(-) create mode 100755 egs/speech_llm/ASR_LLM/debug.sh diff --git a/egs/speech_llm/ASR_LLM/debug.sh b/egs/speech_llm/ASR_LLM/debug.sh new file mode 100755 index 000000000..3f83cd1ef --- /dev/null +++ b/egs/speech_llm/ASR_LLM/debug.sh @@ -0,0 +1,27 @@ + +export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm +# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html +# pip install -r whisper/requirements.txt +export CUDA_VISIBLE_DEVICES=0,1 +# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \ +# --max-duration 80 \ +# --exp-dir ./whisper_llm_zh/exp_test \ +# --speech-encoder-path-or-name tiny \ +# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \ +# --manifest-dir data/fbank \ +# --deepspeed \ +# --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ +# --use-flash-attn False + + + +python3 ./whisper_llm_zh/decode.py \ + --max-duration 80 \ + --exp-dir ./whisper_llm_zh/exp_test \ + --speech-encoder-path-or-name tiny \ + --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \ + --epoch 1 --avg 1 \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --use-flash-attn False \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/run.sh b/egs/speech_llm/ASR_LLM/run.sh index 123f74c03..b73fda468 100755 --- a/egs/speech_llm/ASR_LLM/run.sh +++ b/egs/speech_llm/ASR_LLM/run.sh @@ -1,14 +1,16 @@ -export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm -# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html -# pip install -r whisper/requirements.txt -export CUDA_VISIBLE_DEVICES=0,1 -torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \ - --max-duration 80 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name tiny \ - --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \ +export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall +#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html +#pip install -r whisper_llm_zh/requirements.txt +#export CUDA_VISIBLE_DEVICES=0,1 + +whisper_path=/workspace/asr/icefall_asr_multi-hans-zh_whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt +llm_path=/workspace/asr/Qwen1.5-7B-Chat +torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ + --max-duration 100 \ + --exp-dir ./whisper_llm_zh/exp_qwen_7b \ + --speech-encoder-path-or-name $whisper_path \ + --llm-path-or-name $llm_path \ --manifest-dir data/fbank \ --deepspeed \ - --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ - --use-flash-attn False \ No newline at end of file + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index f758f546c..666e02508 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -30,15 +30,6 @@ python3 ./whisper/decode.py \ --epoch 999 --avg 1 \ --beam-size 10 --max-duration 
50
-# Command for decoding using pretrained models (before fine-tuning):
-
-python3 ./whisper/decode.py \
-  --exp-dir whisper/exp_large_v2 \
-  --model-name large-v2 \
-  --epoch -1 --avg 1 \
-  --remove-whisper-encoder-input-length-restriction False \
-  --beam-size 10 --max-duration 50
-
 """
 
 import argparse
@@ -70,7 +61,7 @@ from icefall.utils import (
     str2bool,
     write_error_stats,
 )
-
+from model import SPEECH_LLM, EncoderProjector
+from train import DEFAULT_SPEECH_TOKEN
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -123,48 +114,27 @@ def average_checkpoints(
     return avg
 
 
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--llm-path-or-name",
+        type=str,
+        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        help="Path or name of the large language model.",
+    )
 
-def remove_punctuation(text: str or List[str]):
-    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-
-    Args:
-        text: It can be a string or a list of strings.
-    Returns:
-        Return a string or a list of strings without any punctuation.
-    """
-    punctuation = "!,.;:?、!,。;:?《》 "
-    if isinstance(text, str):
-        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
-        return text
-    elif isinstance(text, list):
-        result_text = []
-        for t in text:
-            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
-            result_text.append(t)
-        return result_text
-    else:
-        raise Exception(f"Not support type {type(text)}")
-
-
-def to_simple(text: str or List[str]):
-    """Convert traditional Chinese to simplified Chinese.
-    Args:
-        text: It can be a string or a list of strings.
-    Returns:
-        Return a string or a list of strings converted to simplified Chinese.
-    """
-    if isinstance(text, str):
-        text = convert(text, "zh-cn")
-        return text
-    elif isinstance(text, list):
-        result_text = []
-        for t in text:
-            t = convert(t, "zh-cn")
-            result_text.append(t)
-        return result_text
-    else:
-        raise Exception(f"Not support type{type(text)}")
+    parser.add_argument(
+        "--speech-encoder-path-or-name",
+        type=str,
+        default="whisper-large-v2",
+        help="Path or name of the speech encoder.",
+    )
 
+    parser.add_argument(
+        "--encoder-projector-ds-rate",
+        type=int,
+        default=4,
+        help="Downsample rate for the encoder projector.",
+    )
+
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention in the LLM.",
+    )
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -211,15 +181,6 @@ def get_parser():
         help="The experiment dir",
    )
 
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
-        help="""The model name to use.
-        """,
-    )
-
     parser.add_argument(
         "--remove-whisper-encoder-input-length-restriction",
         type=str2bool,
@@ -227,13 +188,7 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
    )
 
-    parser.add_argument(
-        "--use-distill-whisper",
-        type=str2bool,
-        default=False,
-        help="Whether to use architecture of distill whisper.",
-    )
-
+    add_model_arguments(parser)
     return parser
 
 
@@ -249,6 +204,7 @@ def get_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
+    tokenizer: AutoTokenizer,
     batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -266,8 +222,33 @@ def decode_one_batch(
     Returns:
       Return a dict, whose key may be "beam-search".
""" + def preprocess( + messages, + tokenizer: transformers.PreTrainedTokenizer, + max_len: int = 128, + ) -> Dict: + """Preprocesses the data for supervised fine-tuning.""" + texts = [] + for i, msg in enumerate(messages): + texts.append( + tokenizer.apply_chat_template( + msg, + tokenize=True, + add_generation_prompt=False, + padding="max_length", + max_length=max_len, + truncation=True, + ) + ) + + input_ids = torch.tensor(texts, dtype=torch.int) + + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + return input_ids, attention_mask + dtype = torch.float16 - device = torch.device("cuda") + device = model.device feature = batch["inputs"] assert feature.ndim == 3 @@ -288,12 +269,25 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_len = supervisions["num_frames"] feature_len = feature_len.to(device, dtype=dtype) - results = model.decode(feature, params.decoding_options) - hyps = [result.text for result in results] - hyps = remove_punctuation(hyps) - hyps = to_simple(hyps) - hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + messages = [] + for i, text in enumerate(texts): + message = [ + {"role": "system", "content": "你是一个能处理音频的助手。"}, + {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, + {"role": "assistant", "content": ""}, + ] + messages.append(message) + input_ids, attention_mask = preprocess( + messages, tokenizer, max_len=128 + ) + + model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device)) + hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + # hyps = remove_punctuation(hyps) + # hyps = to_simple(hyps) + # hyps = [params.normalizer.normalize(hyp) for hyp in hyps] print(hyps) return {"beam-search": hyps} @@ -302,6 +296,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + tokenizer: AutoTokenizer, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
@@ -370,6 +365,7 @@ def decode_dataset( params=params, model=model, batch=batch, + tokenizer=tokenizer, ) for lm_scale, hyps in hyps_dict.items(): @@ -455,16 +451,6 @@ def main(): f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" ) - options = whisper.DecodingOptions( - task="transcribe", - language="zh", - without_timestamps=True, - beam_size=params.beam_size, - ) - params.decoding_options = options - params.cleaner = BasicTextNormalizer() - params.normalizer = Normalizer() - logging.info("Decoding started") logging.info(params) @@ -476,49 +462,68 @@ def main(): if params.remove_whisper_encoder_input_length_restriction: replace_whisper_encoder_forward() - if params.use_distill_whisper: - replace_whisper_decoder_forward() - model = whisper.load_model(params.model_name, "cpu") - if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - if "model" not in checkpoint: - # deepspeed converted checkpoint only contains model state_dict - filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" - for epoch in range(start, params.epoch + 1) - ] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) + + whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") + speech_encoder = whisper_model.encoder + speech_encoder_dim = whisper_model.dims.n_audio_state + + if params.use_flash_attn: + attn_implementation = "flash_attention_2" + torch_dtype=torch.bfloat16 + + else: + attn_implementation = "eager" + torch_dtype=torch.float16 + + llm = AutoModelForCausalLM.from_pretrained( + params.llm_path_or_name, + attn_implementation=attn_implementation, + torch_dtype=torch_dtype, + ) + tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) + tokenizer.padding_side = 'left' + special_tokens_dict = { + "additional_special_tokens": [DEFAULT_SPEECH_TOKEN] + } + tokenizer.add_special_tokens(special_tokens_dict) + llm.config.pad_token_id = tokenizer.pad_token_id + llm.config.bos_token_id = tokenizer.bos_token_id + llm.config.eos_token_id = tokenizer.eos_token_id + llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + + encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size) + + model = SPEECH_LLM( + speech_encoder, + llm, + encoder_projector, + ) + + + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + assert "model" not in checkpoint + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames), strict=False) + + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + 
torch.save(model.state_dict(), filename)
+    else:
+        checkpoint = torch.load(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+        )
+        if "model" not in checkpoint:
+            model.load_state_dict(checkpoint, strict=False)
         else:
-            checkpoint = torch.load(
-                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-            )
-            if "model" not in checkpoint:
-                model.load_state_dict(checkpoint, strict=True)
-            else:
-                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
 
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
@@ -534,13 +539,14 @@ def main():
 
         # Keep only utterances with duration in 30 seconds
-        # if c.duration > 30.0:
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-        #     )
+        if c.duration > 30.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
         return True
 
-    test_sets_cuts = multi_dataset.test_cuts()
+    # test_sets_cuts = multi_dataset.test_cuts()
+    test_sets_cuts = multi_dataset.aishell_test_cuts()
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
@@ -553,6 +559,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            tokenizer=tokenizer,
         )
 
         save_results(params=params, test_set_name=test_set, results_dict=results_dict)
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index e5096cb4d..10cc18abf 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -38,7 +38,7 @@ class SPEECH_LLM(nn.Module):
         self.encoder_projector = encoder_projector
         self.encoder_outputs_downsample_rate = 4
 
-    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
         num_speechs, speech_len, embed_dim = speech_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
@@ -132,18 +132,45 @@ class SPEECH_LLM(nn.Module):
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
 
-        # outputs = self.llm(
-        #     attention_mask=attention_mask,
-        #     position_ids=position_ids,
-        #     past_key_values=past_key_values,
-        #     inputs_embeds=inputs_embeds,
-        #     use_cache=use_cache,
-        #     output_attentions=output_attentions,
-        #     output_hidden_states=output_hidden_states,
-        #     return_dict=return_dict,
-        # )
-        # logits = outputs[0]
-
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
 
         return model_outputs
+
+    def decode(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+
+        encoder_outs = self.encoder(fbank)
+        # downsample encoder_outs by 4
+        encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
+        speech_features = self.encoder_projector(encoder_outs)
+
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
+            speech_features, inputs_embeds, input_ids, attention_mask
+        )
+
+        generated_ids = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 1),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            bos_token_id=self.llm.config.bos_token_id,
+            eos_token_id=self.llm.config.eos_token_id,
+            pad_token_id=self.llm.config.pad_token_id,
+        )
+        # With only `inputs_embeds` as input, `generate` returns just the newly
+        # generated token ids, so they can be returned (and decoded) directly.
+        return generated_ids
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index f6eacab01..2813bb80d 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -267,4 +267,15 @@ class MultiDataset:
             self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
         )
 
-        return aishell_dev_cuts
\ No newline at end of file
+        return aishell_dev_cuts
+
+    def aishell_test_cuts(self) -> CutSet:
+        logging.info("About to get multidataset test cuts")
+
+        # AISHELL
+        logging.info("Loading Aishell set in lazy mode")
+        aishell_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
+        )
+
+        return aishell_test_cuts
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index b3674be01..3315a5b53 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -725,13 +725,16 @@ def run(rank, world_size, args):
 
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
+        torch_dtype = torch.bfloat16
 
     else:
         attn_implementation = "eager"
+        torch_dtype = torch.float16
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
+        torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     tokenizer.padding_side = 'left'
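
For reference, the following is a minimal inference sketch of how the pieces introduced by this patch fit together: the chat prompt carrying DEFAULT_SPEECH_TOKEN, the Whisper encoder features, and SPEECH_LLM.decode(). It is an illustration only and not part of the patch; the checkpoint path, LLM name, and the random fbank tensor are placeholders, and it assumes whisper_llm_zh/model.py and train.py from this patch are importable.

# Minimal inference sketch (illustration only, not part of the patch).
# Paths and model names below are placeholders.
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer

from model import SPEECH_LLM, EncoderProjector
from train import DEFAULT_SPEECH_TOKEN

llm_path = "Qwen/Qwen1.5-0.5B-Chat"                      # placeholder LLM
ckpt_path = "whisper_llm_zh/exp_test/epoch-1-avg-1.pt"   # placeholder checkpoint

# Assemble the same SPEECH_LLM as decode.py: Whisper encoder + projector + LLM.
whisper_model = whisper.load_model("tiny", "cpu")
llm = AutoModelForCausalLM.from_pretrained(llm_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(llm_path)
tokenizer.padding_side = "left"
tokenizer.add_special_tokens({"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]})
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

model = SPEECH_LLM(
    whisper_model.encoder,
    llm,
    EncoderProjector(whisper_model.dims.n_audio_state, llm.config.hidden_size),
)
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
model.to("cuda", dtype=torch.float16).eval()

# One utterance worth of 80-bin fbank features, shape (batch, n_mels, frames).
# A real run would take these from the Lhotse dataloader, padded to 3000 frames.
fbank = torch.randn(1, 80, 3000, dtype=torch.float16, device="cuda")

# Same chat prompt as decode_one_batch(); the speech placeholder token is
# replaced by the projected encoder output inside SPEECH_LLM.
message = [
    {"role": "system", "content": "你是一个能处理音频的助手。"},
    {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]
ids = tokenizer.apply_chat_template(
    message,
    tokenize=True,
    add_generation_prompt=False,
    padding="max_length",
    max_length=128,
    truncation=True,
)
input_ids = torch.tensor([ids], dtype=torch.long, device="cuda")
attention_mask = input_ids.ne(tokenizer.pad_token_id)

with torch.no_grad():
    generated_ids = model.decode(fbank, input_ids, attention_mask)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])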