refine decoding method

root 2025-04-15 06:53:20 +00:00
parent 3ad075af60
commit 0c02da82ac
3 changed files with 79 additions and 32 deletions

View File

@@ -52,12 +52,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: "
python3 ./slam_omni/decode.py \
--max-duration 80 \
--exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
--exp-dir slam_omni/exp_speech2text \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/qwen \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method pure_text_sampling \
--use-lora True # --on-the-fly-feats True
fi
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "stage 2: "
python3 ./slam_omni/decode.py \
--max-duration 80 \
--exp-dir slam_omni/exp_speech2text \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method pure_text_sampling_original_0.5B \
--use-lora False # --on-the-fly-feats True
fi
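
The two stages differ only in which LLM is scored: stage 2 evaluates the LoRA fine-tuned checkpoint in slam_omni/exp_speech2text, while stage 20 evaluates the unmodified Qwen2.5-0.5B-Instruct as a baseline (--use-lora False). As a minimal sketch, the shared flags can be factored into a hypothetical Python wrapper (only the flag values come from the script above):

import subprocess

# Flags shared by both decode runs, taken from the script above.
COMMON = [
    "python3", "./slam_omni/decode.py",
    "--max-duration", "80",
    "--exp-dir", "slam_omni/exp_speech2text",
    "--speech-encoder-path-or-name",
    "models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt",
    "--llm-path-or-name", "models/Qwen2.5-0.5B-Instruct",
    "--epoch", "999", "--avg", "1",
    "--manifest-dir", "data/fbank",
    "--use-flash-attn", "True",
]

# stage 2: LoRA fine-tuned model; stage 20: original 0.5B baseline.
RUNS = {
    "pure_text_sampling": "True",
    "pure_text_sampling_original_0.5B": "False",
}

for method, use_lora in RUNS.items():
    subprocess.run(COMMON + ["--method", method, "--use-lora", use_lora], check=True)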

View File

@@ -52,7 +52,6 @@ from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
@@ -60,13 +59,12 @@ import whisper
from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
# from data_module import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
@@ -76,7 +74,6 @@ from icefall.utils import (
write_error_stats,
)
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
@@ -133,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
default="",
help="Path or name of the large language model.",
)
@@ -264,7 +261,6 @@ def decode_one_batch(
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int = 128,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
@@ -277,8 +273,7 @@ def decode_one_batch(
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
max_length=max_len,
truncation=True,
truncation=False,
)
)
max_len_texts = max([len(text) for text in texts])
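
Dropping max_len means the chat template is now applied without truncation; with multi-round histories in the prompt (next hunk), a fixed cap would silently cut earlier turns, so the code pads every text to the longest in the batch instead. A minimal sketch of the same call against a stock tokenizer (model id and messages are illustrative, not from this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "user", "content": "List some news headlines."},
    {"role": "assistant", "content": ""},
]
# With truncation disabled the full history survives; per-batch padding
# is handled afterwards by padding every text to the longest one.
ids = tok.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    truncation=False,
)
print(len(ids))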
@@ -318,18 +313,38 @@ def decode_one_batch(
2,
)
# supervisions = batch["supervisions"]
# feature_len = supervisions["num_frames"]
# feature_len = feature_len.to(device, dtype=dtype)
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
messages = [
[
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
# messages = [
# [
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
# {"role": "assistant", "content": ""},
# ]
# ] * len(feature)
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
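# the history context is one flat string of alternating "USER: ... ASSISTANT: ..." turns; split it back into chat messages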
history_question_answer = history_contexts[i].split('USER:')
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# Example history: USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
# (USER: Write a poem about summer. ASSISTANT: The summer heat blazes, all things grow; bright sunshine, enjoying summer's best days. USER: List some news headlines. ASSISTANT: The news of today's society never stops.)
question_answer = history_question_answer[j].split('ASSISTANT:')
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()}
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "user", "content": f"{last_questions[i]}"},
{"role": "assistant", "content": ""}
]
] * len(feature)
messages.append(message)
input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
input_ids, attention_mask = preprocess(messages, tokenizer)
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
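
To see what the new prompt construction produces, here is a standalone sketch of the same parsing applied to the example history from the comment above (the function name and speech-token value are illustrative):

DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder; the real token comes from train.py

def history_to_messages(history: str, total_round: int) -> list:
    """Split a flat 'USER: ... ASSISTANT: ...' history string into chat messages."""
    messages = []
    if total_round > 1:
        turns = [t for t in history.split("USER:") if t]
        for turn in turns[: total_round - 1]:
            question, answer = turn.split("ASSISTANT:")
            messages.append({"role": "user", "content": question.strip()})
            messages.append({"role": "assistant", "content": answer.strip()})
    # the current round is carried by the speech token, not by text
    messages.append({"role": "user", "content": DEFAULT_SPEECH_TOKEN})
    messages.append({"role": "assistant", "content": ""})
    return messages

history = ("USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,"
           "享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。")
print(history_to_messages(history, total_round=3))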
@@ -422,7 +437,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
# ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
@@ -449,7 +464,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
@@ -459,7 +474,7 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
@@ -475,7 +490,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -499,8 +514,10 @@ def main():
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
params.log_dir.mkdir(parents=True, exist_ok=True)
setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
)
logging.info("Decoding started")

View File

@@ -241,27 +241,41 @@ class SPEECH_LLM(nn.Module):
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
position_ids,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
max_new_tokens=kwargs.get("max_new_tokens", 1024),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
do_sample=kwargs.get("do_sample", True),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
top_p=kwargs.get("top_p", 0.5),
top_k=kwargs.get("top_k", 20),
repetition_penalty=kwargs.get("repetition_penalty", 1.1),
temperature=kwargs.get("temperature", 0.7),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
# generated_ids = self.llm.generate(
# inputs_embeds=inputs_embeds,
# max_new_tokens=kwargs.get("max_new_tokens", 200),
# num_beams=kwargs.get("num_beams", 1),
# do_sample=kwargs.get("do_sample", False),
# min_length=kwargs.get("min_length", 1),
# top_p=kwargs.get("top_p", 1.0),
# repetition_penalty=kwargs.get("repetition_penalty", 1.0),
# temperature=kwargs.get("temperature", 1.0),
# length_penalty=kwargs.get("length_penalty", 1.0),
# bos_token_id=self.llm.config.bos_token_id,
# eos_token_id=self.llm.config.eos_token_id,
# pad_token_id=self.llm.config.pad_token_id,
# )
return generated_ids
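
The rewritten generate call switches from greedy search to sampling (do_sample=True with temperature 0.7, top_p 0.5, top_k 20, repetition_penalty 1.1) and raises the token budget from 200 to 1024. A minimal sketch of the same settings on a stock Hugging Face causal LM (model id and prompt are illustrative; the repo feeds speech embeddings instead of a text prompt):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative; the repo loads its own checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
llm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tok("List some news headlines.", return_tensors="pt").to(llm.device)
generated = llm.generate(
    **inputs,
    max_new_tokens=1024,
    do_sample=True,           # sample instead of greedy decoding
    temperature=0.7,
    top_p=0.5,
    top_k=20,
    repetition_penalty=1.1,
    pad_token_id=tok.pad_token_id or tok.eos_token_id,
)
# strip the prompt tokens before decoding the continuation
print(tok.decode(generated[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))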