mirror of https://github.com/k2-fsa/icefall.git
fix decoding issues

commit 19b5b86f9b
parent 3dbbc29429
@@ -22,6 +22,4 @@ python3 ./whisper_llm_zh/decode.py \
   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
   --epoch 1 --avg 1 \
   --manifest-dir data/fbank \
-  --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn False
@@ -46,12 +46,14 @@ import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
-from tn.chinese.normalizer import Normalizer
-from whisper.normalizers import BasicTextNormalizer
-from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+#from tn.chinese.normalizer import Normalizer
+#from whisper.normalizers import BasicTextNormalizer
+#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from zhconv import convert
+#from zhconv import convert
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from model import EncoderProjector, SPEECH_LLM
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -188,6 +190,13 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
     add_model_arguments(parser)
     return parser

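The new --use-flash-attn option is declared with type=str2bool, which is why the command hunk above can pass "--use-flash-attn False" as a plain string. icefall exposes a str2bool helper in icefall.utils; the snippet below is a minimal sketch of what such a converter typically does, not a copy of that helper.

import argparse

def str2bool(v):
    # Accept common command-line spellings of booleans.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--use-flash-attn", type=str2bool, default=True)
print(parser.parse_args(["--use-flash-attn", "False"]).use_flash_attn)  # False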
@@ -247,8 +256,8 @@ def decode_one_batch(

         return input_ids, attention_mask

-    dtype = torch.float16
-    device = model.device
+    dtype = torch.float32
+    device = model.llm.device

     feature = batch["inputs"]
     assert feature.ndim == 3
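The device fix matters because SPEECH_LLM is a plain torch.nn.Module, which has no .device attribute, while the wrapped Hugging Face LLM (a transformers PreTrainedModel) does expose one; reading the device from model.llm therefore works. A minimal sketch of the distinction, using a stand-in wrapper rather than the actual SPEECH_LLM:

import torch.nn as nn

class Wrapper(nn.Module):
    # Stand-in for a model that wraps a submodule (e.g. an HF causal LM).
    def __init__(self):
        super().__init__()
        self.llm = nn.Linear(4, 4)

model = Wrapper()
print(hasattr(model, "device"))                 # False: plain nn.Module has no .device
print(next(model.llm.parameters()).device)      # query a submodule's parameters instead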
@@ -270,19 +279,17 @@ def decode_one_batch(
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)

-    messages = []
-    for i, text in enumerate(texts):
-        message = [
-            {"role": "system", "content": "你是一个能处理音频的助手。"},
-            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": ""},
-        ]
-        messages.append(message)
+    messages = [[
+        {"role": "system", "content": "你是一个能处理音频的助手。"},
+        {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+        {"role": "assistant", "content": ""},
+    ]] * len(feature)

     input_ids, attention_mask = preprocess(
         messages, tokenizer, max_len=128
     )

-    model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device))
+    generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

     # hyps = remove_punctuation(hyps)
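The rewritten code builds one identical chat prompt per utterance in the batch ([[...]] * len(feature)) instead of looping over the reference texts, and preprocess turns those messages into padded input_ids and attention_mask. The sketch below shows roughly how that tokenization step could look with the Qwen chat tokenizer; the recipe's own preprocess helper lives elsewhere in decode.py and may differ in details, and the DEFAULT_SPEECH_TOKEN value here is only a placeholder.

from transformers import AutoTokenizer

DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder; the recipe defines its own token

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

messages = [[
    {"role": "system", "content": "你是一个能处理音频的助手。"},
    {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]] * 2  # e.g. a batch of two utterances

texts = [
    tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=False)
    for m in messages
]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
input_ids, attention_mask = encoded.input_ids, encoded.attention_mask
print(input_ids.shape, attention_mask.shape)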
@@ -125,6 +125,7 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = self.encoder(fbank)
         # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]

         speech_features = self.encoder_projector(encoder_outs)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
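For reference, the [:, ::self.encoder_outputs_downsample_rate] slice kept as context in the hunk above downsamples the encoder output along the time axis by keeping every N-th frame, which shortens the sequence handed to the projector and the LLM. A toy illustration with a rate of 4:

import torch

encoder_outs = torch.randn(2, 100, 8)   # (batch, time, feature)
downsample_rate = 4
downsampled = encoder_outs[:, ::downsample_rate]
print(downsampled.shape)                # torch.Size([2, 25, 8])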
@@ -141,20 +142,20 @@ class SPEECH_LLM(nn.Module):
         fbank: torch.Tensor = None,
         input_ids: torch.LongTensor = None,
         attention_mask: torch.Tensor = None,
-        **kwargs:
+        **kwargs
     ):

         encoder_outs = self.encoder(fbank)
         # downsample encoder_outs by 4
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
-        speech_features = self.encoder_projector(encoder_outs)

+        speech_features = self.encoder_projector(encoder_outs)
+        speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
-        model_outputs = self.llm.generate(
+        generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
             max_new_tokens=kwargs.get("max_new_tokens", 200),
             num_beams=kwargs.get("num_beams", 1),
@@ -164,11 +165,9 @@ class SPEECH_LLM(nn.Module):
             repetition_penalty=kwargs.get("repetition_penalty", 1.0),
             length_penalty=kwargs.get("length_penalty", 1.0),
             temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
-            pad_token_id=self.tllm.config.pad_token_id
+            pad_token_id=self.llm.config.pad_token_id
         )
         generated_ids = [
             output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
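In the rewritten decode, the projected speech features are cast to float16, spliced into the prompt embeddings, and llm.generate is driven purely by inputs_embeds (the explicit attention_mask and position_ids arguments are dropped). The sketch below shows text-only generation from embeddings with a Hugging Face causal LM, assuming the Qwen checkpoint named in the command above; the speech-feature splicing itself is omitted, so this is not the recipe's SPEECH_LLM path.

from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(name)
llm = AutoModelForCausalLM.from_pretrained(name)

prompt_ids = tokenizer("请转写音频为文字", return_tensors="pt").input_ids
# Look up embeddings and generate from them, as SPEECH_LLM.decode does after
# merging speech features into the embedded prompt.
inputs_embeds = llm.get_input_embeddings()(prompt_ids)
generated_ids = llm.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=32,
    num_beams=1,
    eos_token_id=llm.config.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))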
@@ -278,4 +278,6 @@ class MultiDataset:
             self.fbank_dir / "aishell_cuts_test.jsonl.gz"
         )

-        return aishell_test_cuts
+        return {
+            "aishell_test": aishell_test_cuts,
+        }
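Returning a dict instead of a bare CutSet lets the caller decode several named test sets with one code path. A minimal, self-contained sketch of the iteration pattern such a mapping enables; the loop is illustrative only and the toy values stand in for lhotse CutSets.

# Toy stand-in for the mapping MultiDataset now returns
# (real values are lhotse CutSets, e.g. {"aishell_test": aishell_test_cuts}).
test_sets_cuts = {"aishell_test": ["cut_0001", "cut_0002"]}

for name, cuts in test_sets_cuts.items():
    # In decode.py each named cut set would get its own dataloader and
    # its own results file; here we just show the iteration.
    print(f"decoding test set: {name} ({len(cuts)} cuts)")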
|
Loading…
x
Reference in New Issue
Block a user