From 19b5b86f9be529e7b9920863aad8c0155a772d84 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Thu, 6 Jun 2024 12:25:10 +0800
Subject: [PATCH] fix decoding issues

---
 egs/speech_llm/ASR_LLM/debug.sh               |  2 -
 .../ASR_LLM/whisper_llm_zh/decode.py          | 39 +++++++++++--------
 .../ASR_LLM/whisper_llm_zh/model.py           | 15 ++++---
 .../ASR_LLM/whisper_llm_zh/multi_dataset.py   |  4 +-
 4 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/debug.sh b/egs/speech_llm/ASR_LLM/debug.sh
index 3f83cd1ef..53f411870 100755
--- a/egs/speech_llm/ASR_LLM/debug.sh
+++ b/egs/speech_llm/ASR_LLM/debug.sh
@@ -22,6 +22,4 @@ python3 ./whisper_llm_zh/decode.py \
   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
   --epoch 1 --avg 1 \
   --manifest-dir data/fbank \
-  --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn False
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 666e02508..54a083c1e 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -46,12 +46,14 @@ import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
-from tn.chinese.normalizer import Normalizer
-from whisper.normalizers import BasicTextNormalizer
-from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+#from tn.chinese.normalizer import Normalizer
+#from whisper.normalizers import BasicTextNormalizer
+#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from zhconv import convert
-
+#from zhconv import convert
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from model import EncoderProjector, SPEECH_LLM
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -188,6 +190,13 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -247,8 +256,8 @@ def decode_one_batch(

         return input_ids, attention_mask

-    dtype = torch.float16
-    device = model.device
+    dtype = torch.float32
+    device = model.llm.device

     feature = batch["inputs"]
     assert feature.ndim == 3
@@ -270,19 +279,17 @@ def decode_one_batch(
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)

-    messages = []
-    for i, text in enumerate(texts):
-        message = [
-            {"role": "system", "content": "你是一个能处理音频的助手。"},
-            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": ""},
-        ]
-        messages.append(message)
+    messages = [[
+        {"role": "system", "content": "你是一个能处理音频的助手。"},
+        {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+        {"role": "assistant", "content": ""},
+    ]] * len(feature)
+

     input_ids, attention_mask = preprocess(
         messages, tokenizer, max_len=128
     )

-    model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device))
+    generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
     # hyps = remove_punctuation(hyps)
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index 10cc18abf..796cb2c9d 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -12,7 +12,7 @@ class EncoderProjector(nn.Module):
         self.relu = nn.ReLU()
         self.linear2 = nn.Linear(llm_dim, llm_dim)

-    def forward(self, x): 
+    def forward(self, x):
         x = self.linear1(x)
         x = self.relu(x)
         x = self.linear2(x)
@@ -125,6 +125,7 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = self.encoder(fbank)
         # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
+
         speech_features = self.encoder_projector(encoder_outs)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@@ -141,34 +142,32 @@ class SPEECH_LLM(nn.Module):
         fbank: torch.Tensor = None,
         input_ids: torch.LongTensor = None,
         attention_mask: torch.Tensor = None,
-        **kwargs:
+        **kwargs
     ):
         encoder_outs = self.encoder(fbank)
         # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
+
         speech_features = self.encoder_projector(encoder_outs)
-
+        speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)

         inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
-
-        model_outputs = self.llm.generate(
+        generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
             max_new_tokens=kwargs.get("max_new_tokens", 200),
             num_beams=kwargs.get("num_beams", 1),
             do_sample=kwargs.get("do_sample", False),
             min_length=kwargs.get("min_length", 1),
             top_p=kwargs.get("top_p", 1.0),
             repetition_penalty=kwargs.get("repetition_penalty", 1.0),
             length_penalty=kwargs.get("length_penalty", 1.0),
             temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
-            pad_token_id=self.tllm.config.pad_token_id
+            pad_token_id=self.llm.config.pad_token_id
         )
         generated_ids = [
             output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index 2813bb80d..e4b148ea5 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -278,4 +278,6 @@ class MultiDataset:
             self.fbank_dir / "aishell_cuts_test.jsonl.gz"
         )

-        return aishell_test_cuts
\ No newline at end of file
+        return {
+            "aishell_test": aishell_test_cuts,
+        }
\ No newline at end of file
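-- 
Note (not part of the commit): a minimal sketch of the decoding flow these changes
converge on. The `transcribe_batch` wrapper and the DEFAULT_SPEECH_TOKEN value are
illustrative assumptions; `model` (a SPEECH_LLM), `tokenizer` (the Qwen chat
tokenizer), and `preprocess` are assumed to come from the recipe's model.py and
decode.py.

import torch

# Illustrative value; decode.py defines the actual DEFAULT_SPEECH_TOKEN.
DEFAULT_SPEECH_TOKEN = "<speech>"


def transcribe_batch(model, tokenizer, preprocess, feature):
    # `feature` is a (batch, num_mels, num_frames) fbank tensor.
    device = model.llm.device
    feature = feature.to(device, dtype=torch.float32)

    # One identical prompt per utterance, as the patched decode_one_batch builds it.
    # system: "You are an assistant that can process audio."
    # user:   "Please transcribe the audio to text <speech token>"
    messages = [[
        {"role": "system", "content": "你是一个能处理音频的助手。"},
        {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
        {"role": "assistant", "content": ""},
    ]] * len(feature)

    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

    # Token ids must be int64, hence dtype=torch.long (a torch.dtype) rather than
    # torch.LongTensor (a tensor class), which Tensor.to() rejects.
    generated_ids = model.decode(
        feature,
        input_ids.to(device, dtype=torch.long),
        attention_mask.to(device),
    )
    # SPEECH_LLM.decode() already strips the prompt ids from each output sequence.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)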