fix decoding issues

Yuekai Zhang 2024-06-06 12:25:10 +08:00
parent 3dbbc29429
commit 19b5b86f9b
4 changed files with 33 additions and 27 deletions


@@ -22,6 +22,4 @@ python3 ./whisper_llm_zh/decode.py \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
--epoch 1 --avg 1 \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn False


@@ -46,12 +46,14 @@ import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
#from tn.chinese.normalizer import Normalizer
#from whisper.normalizers import BasicTextNormalizer
#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
#from zhconv import convert
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from model import EncoderProjector, SPEECH_LLM
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -188,6 +190,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
add_model_arguments(parser)
return parser
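The new --use-flash-attn switch is parsed with icefall's str2bool helper. How decode.py ultimately consumes it is not shown in this diff; below is a minimal sketch of one plausible wiring, using the standard transformers attn_implementation argument (the helper name and the fp16/fp32 coupling are assumptions, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

def load_llm(llm_path_or_name: str, use_flash_attn: bool):
    # Hypothetical helper: flash attention requires fp16/bf16 weights,
    # while the eager path can stay in fp32 (consistent with the
    # float16 -> float32 dtype change further down in this commit).
    return AutoModelForCausalLM.from_pretrained(
        llm_path_or_name,
        torch_dtype=torch.float16 if use_flash_attn else torch.float32,
        attn_implementation="flash_attention_2" if use_flash_attn else "eager",
    )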
@@ -247,8 +256,8 @@ def decode_one_batch(
return input_ids, attention_mask
dtype = torch.float16
device = model.device
dtype = torch.float32
device = model.llm.device
feature = batch["inputs"]
assert feature.ndim == 3
@@ -270,19 +279,17 @@ def decode_one_batch(
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
messages = []
for i, text in enumerate(texts):
message = [
{"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
messages.append(message)
messages = [[
{"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]] * len(feature)
input_ids, attention_mask = preprocess(
messages, tokenizer, max_len=128
)
model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device))
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# hyps = remove_punctuation(hyps)
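Two details of this hunk are easy to miss: the chat prompt is now replicated once per utterance in the batch instead of being rebuilt from the reference texts, and the token ids are cast with the dtype object torch.long rather than the tensor class torch.LongTensor, which Tensor.to() does not accept as a dtype. A small self-contained illustration (not the recipe's code; DEFAULT_SPEECH_TOKEN is a stand-in value here):

import torch

DEFAULT_SPEECH_TOKEN = "<speech>"  # stand-in; the real token is defined elsewhere in decode.py
batch_size = 3

# One identical prompt per utterance; the reference texts are no longer needed to build it.
messages = [[
    {"role": "system", "content": "你是一个能处理音频的助手。"},
    {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]] * batch_size

# torch.LongTensor is a tensor class, not a dtype, so passing it to .to(dtype=...)
# raises a TypeError; torch.long (int64) is the dtype object .to() expects.
ids = torch.tensor([[1, 2, 3]])
ids = ids.to("cpu", dtype=torch.long)
print(len(messages), ids.dtype)  # 3 torch.int64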


@@ -12,7 +12,7 @@ class EncoderProjector(nn.Module):
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
@@ -125,6 +125,7 @@ class SPEECH_LLM(nn.Module):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
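The added line thins the encoder output along the time axis before it reaches the projector, so the LLM sees one frame in every encoder_outputs_downsample_rate. A minimal sketch of what the slicing does (the shapes and rate are illustrative, not taken from the model config):

import torch

downsample_rate = 4                                # illustrative value
encoder_outs = torch.randn(2, 100, 512)            # (batch, frames, encoder_dim)
downsampled = encoder_outs[:, ::downsample_rate]   # keep every 4th frame
print(downsampled.shape)                           # torch.Size([2, 25, 512])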
@@ -141,20 +142,20 @@ class SPEECH_LLM(nn.Module):
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs:
**kwargs
):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
model_outputs = self.llm.generate(
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 1),
@@ -164,11 +165,9 @@ class SPEECH_LLM(nn.Module):
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
position_ids=position_ids,
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.tllm.config.pad_token_id
pad_token_id=self.llm.config.pad_token_id
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
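decode() reads its generation hyper-parameters from **kwargs with per-key defaults, so a caller only overrides what it needs. A hedged sketch of that pattern in isolation (the key names and defaults mirror the call above; this is not the repo's code):

def generation_options(**kwargs):
    # Fall back to the same defaults as SPEECH_LLM.decode when a key is absent.
    return dict(
        max_new_tokens=kwargs.get("max_new_tokens", 200),
        num_beams=kwargs.get("num_beams", 1),
        repetition_penalty=kwargs.get("repetition_penalty", 1.0),
        length_penalty=kwargs.get("length_penalty", 1.0),
        temperature=kwargs.get("temperature", 1.0),
    )

print(generation_options(num_beams=4))  # only num_beams departs from its default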


@@ -278,4 +278,6 @@ class MultiDataset:
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
)
return aishell_test_cuts
return {
"aishell_test": aishell_test_cuts,
}
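Returning a dict keyed by test-set name, rather than a bare CutSet, lets the caller iterate over named test sets and keep their results separate. A hedged sketch of how such a return value is typically consumed (the actual caller in decode.py is not part of this diff):

test_cuts = {
    "aishell_test": ...,  # Ellipsis stands in for the CutSet loaded above
}
for test_set_name, cuts in test_cuts.items():
    print(f"decoding test set: {test_set_name}")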