Mirror of https://github.com/k2-fsa/icefall.git

commit 19b5b86f9b (parent 3dbbc29429)

    fix decoding issues
whisper_llm_zh/decode.py

@@ -22,6 +22,4 @@ python3 ./whisper_llm_zh/decode.py \
   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
   --epoch 1 --avg 1 \
   --manifest-dir data/fbank \
-  --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn False
@@ -46,12 +46,14 @@ import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
-from tn.chinese.normalizer import Normalizer
-from whisper.normalizers import BasicTextNormalizer
-from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+#from tn.chinese.normalizer import Normalizer
+#from whisper.normalizers import BasicTextNormalizer
+#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from zhconv import convert
-
+#from zhconv import convert
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from model import EncoderProjector, SPEECH_LLM
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
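Note: the fix comments out the text-normalization imports that decode.py no longer uses and pulls in the Hugging Face classes plus the local EncoderProjector/SPEECH_LLM model instead. A minimal sketch of how the newly imported classes are typically loaded (the model path matches --llm-path-or-name in the usage string above; this is an illustration, not the exact code in decode.py):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    llm_path = "Qwen/Qwen1.5-0.5B-Chat"  # matches --llm-path-or-name above
    tokenizer = AutoTokenizer.from_pretrained(llm_path)
    llm = AutoModelForCausalLM.from_pretrained(llm_path)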
@@ -188,6 +190,13 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
     )
 
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
     add_model_arguments(parser)
 
     return parser
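Note: --use-flash-attn is parsed with str2bool (imported in the icefall.utils block above), so the usage string's `--use-flash-attn False` becomes a real boolean. A sketch of the conventional str2bool helper, assuming icefall's follows the usual pattern:

    import argparse

    def str2bool(v):
        # Map common true/false spellings to bool; plain type=bool would
        # treat any non-empty string, including "False", as True.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")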
@@ -247,8 +256,8 @@ def decode_one_batch(
 
         return input_ids, attention_mask
 
-    dtype = torch.float16
-    device = model.device
+    dtype = torch.float32
+    device = model.llm.device
 
     feature = batch["inputs"]
     assert feature.ndim == 3
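Note: the device fix matters because SPEECH_LLM is a plain nn.Module, which has no .device attribute, whereas the wrapped Hugging Face model (model.llm) does expose one; the features are also kept in float32 here, with the half-precision cast handled inside the model (see the speech_features hunk below). A tiny illustration of the device distinction (stand-in module, not code from the repo):

    import torch.nn as nn

    wrapper = nn.Sequential(nn.Linear(4, 4))    # stand-in for a plain nn.Module
    # wrapper.device                            # AttributeError on nn.Module
    device = next(wrapper.parameters()).device  # generic fallback for any module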
@@ -270,19 +279,17 @@ def decode_one_batch(
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
 
-    messages = []
-    for i, text in enumerate(texts):
-        message = [
-            {"role": "system", "content": "你是一个能处理音频的助手。"},
-            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": ""},
-        ]
-        messages.append(message)
+    messages = [[
+        {"role": "system", "content": "你是一个能处理音频的助手。"},
+        {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+        {"role": "assistant", "content": ""},
+    ]] * len(feature)
 
     input_ids, attention_mask = preprocess(
         messages, tokenizer, max_len=128
     )
 
-    model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device))
+    generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     # hyps = remove_punctuation(hyps)
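Note: this hunk collapses the per-utterance loop into one prompt list replicated len(feature) times (the system/user prompts read, roughly, "You are an assistant that can process audio." and "Please transcribe the audio into text"), renames model_outputs to generated_ids so the batch_decode line no longer raises a NameError, and fixes the dtype argument: torch.LongTensor is a tensor class, not a dtype, so passing it to Tensor.to() raises a TypeError, while torch.long is the matching dtype. For example:

    import torch

    ids = torch.tensor([1, 2, 3], dtype=torch.int32)
    # ids.to("cpu", dtype=torch.LongTensor)  # TypeError: dtype must be a torch.dtype
    ids = ids.to("cpu", dtype=torch.long)    # OK: now an int64 tensor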
whisper_llm_zh/model.py

@@ -12,7 +12,7 @@ class EncoderProjector(nn.Module):
         self.relu = nn.ReLU()
         self.linear2 = nn.Linear(llm_dim, llm_dim)
 
-    def forward(self, x):
+    def forward(self, x):
         x = self.linear1(x)
         x = self.relu(x)
         x = self.linear2(x)
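For orientation, EncoderProjector is a small two-layer MLP mapping speech-encoder frames into the LLM embedding space. A self-contained sketch consistent with the context lines (the linear1 shape is an assumption; only relu and linear2 are visible above):

    import torch.nn as nn

    class EncoderProjector(nn.Module):
        # Projects (batch, time, encoder_dim) -> (batch, time, llm_dim).
        def __init__(self, encoder_dim: int, llm_dim: int):
            super().__init__()
            self.linear1 = nn.Linear(encoder_dim, llm_dim)  # assumed shape
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(llm_dim, llm_dim)

        def forward(self, x):
            x = self.linear1(x)
            x = self.relu(x)
            x = self.linear2(x)
            return x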
@@ -125,6 +125,7 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = self.encoder(fbank)
+        # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
 
         speech_features = self.encoder_projector(encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
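Note: the added comment documents the strided slice: encoder_outs[:, ::rate] keeps every rate-th frame along the time axis of a (batch, time, feature) tensor, a cheap 4x temporal downsampling before projection. For instance:

    import torch

    encoder_outs = torch.randn(2, 100, 512)  # (batch, time, feature)
    downsampled = encoder_outs[:, ::4]       # keeps frames 0, 4, 8, ...
    print(downsampled.shape)                 # torch.Size([2, 25, 512])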
@@ -141,20 +142,20 @@ class SPEECH_LLM(nn.Module):
         fbank: torch.Tensor = None,
         input_ids: torch.LongTensor = None,
         attention_mask: torch.Tensor = None,
-        **kwargs:
+        **kwargs
     ):
 
         encoder_outs = self.encoder(fbank)
+        # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
 
         speech_features = self.encoder_projector(encoder_outs)
 
+        speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
 
-        model_outputs = self.llm.generate(
+        generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
             max_new_tokens=kwargs.get("max_new_tokens", 200),
             num_beams=kwargs.get("num_beams", 1),
@@ -164,11 +165,9 @@ class SPEECH_LLM(nn.Module):
             repetition_penalty=kwargs.get("repetition_penalty", 1.0),
             length_penalty=kwargs.get("length_penalty", 1.0),
             temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
-            pad_token_id=self.tllm.config.pad_token_id
+            pad_token_id=self.llm.config.pad_token_id
         )
         generated_ids = [
             output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
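Note: three fixes land in this generate() call: the result is renamed to generated_ids, which the prompt-stripping list comprehension above expects; the self.tllm typo, which would have raised an AttributeError at runtime, becomes self.llm; and attention_mask/position_ids are no longer forwarded, presumably because generate() derives them itself when decoding from inputs_embeds. Turning the returned ids back into text then follows the usual transformers pattern (illustration only; tokenizer is the AutoTokenizer loaded in decode.py):

    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)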
whisper_llm_zh/multi_dataset.py

@@ -278,4 +278,6 @@ class MultiDataset:
             self.fbank_dir / "aishell_cuts_test.jsonl.gz"
         )
 
-        return aishell_test_cuts
+        return {
+            "aishell_test": aishell_test_cuts,
+        }
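Note: returning a dict keyed by test-set name instead of a bare CutSet lets the decode script treat one or many test sets uniformly. A hedged sketch of the consuming side (variable names are assumptions):

    test_sets_cuts = multi_dataset.aishell_test_cuts()  # {"aishell_test": CutSet}
    for test_set_name, cuts in test_sets_cuts.items():
        print(test_set_name, cuts)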