refine decoding method

root 2025-04-15 06:53:20 +00:00
parent 3ad075af60
commit 0c02da82ac
3 changed files with 79 additions and 32 deletions

(decoding script)

@@ -52,12 +52,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
     --max-duration 80 \
-    --exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
+    --exp-dir slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/qwen \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
+    --method pure_text_sampling \
     --use-lora True # --on-the-fly-feats True
 fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 20: "
+  python3 ./slam_omni/decode.py \
+    --max-duration 80 \
+    --exp-dir slam_omni/exp_speech2text \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --epoch 999 --avg 1 \
+    --manifest-dir data/fbank \
+    --use-flash-attn True \
+    --method pure_text_sampling_original_0.5B \
+    --use-lora False # --on-the-fly-feats True
+fi

slam_omni/decode.py

@@ -52,7 +52,6 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import k2
 import torch
 import torch.nn as nn
 import transformers
@@ -60,13 +59,12 @@ import whisper
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-# from data_module import MultiDataset
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -76,7 +74,6 @@ from icefall.utils import (
     write_error_stats,
 )

-
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
@@ -133,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--llm-path-or-name",
         type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        default="",
         help="Path or name of the large language model.",
     )
@@ -264,7 +261,6 @@ def decode_one_batch(
     def preprocess(
         messages,
         tokenizer: transformers.PreTrainedTokenizer,
-        max_len: int = 128,
     ) -> Dict:
         """Preprocesses the data for supervised fine-tuning."""
         texts = []
@@ -277,8 +273,7 @@ def decode_one_batch(
                     add_generation_prompt=False,
                     chat_template=TEMPLATE,
                     padding="longest",
-                    max_length=max_len,
-                    truncation=True,
+                    truncation=False,
                 )
             )
         max_len_texts = max([len(text) for text in texts])
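Note on the hunk above: with `max_length=128` and `truncation=True`, long multi-round prompts were silently clipped; tokenizing with `truncation=False` keeps the full history and `padding="longest"` sizes each batch. A minimal, self-contained sketch of that pad-to-longest step on plain token-id lists (the ids and `pad_id` below are made-up values, and left padding is one reasonable choice for decoder-only generation, not necessarily what decode.py does):

from typing import Dict, List

import torch


def pad_to_longest(texts: List[List[int]], pad_id: int) -> Dict[str, torch.Tensor]:
    # Pad every token-id sequence to the longest one in the batch.
    max_len_texts = max(len(t) for t in texts)
    padded = [[pad_id] * (max_len_texts - len(t)) + t for t in texts]  # left padding
    input_ids = torch.tensor(padded, dtype=torch.long)
    attention_mask = input_ids.ne(pad_id)  # True on real tokens, False on padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_to_longest([[101, 872, 198], [101]], pad_id=0)
print(batch["input_ids"])       # tensor([[101, 872, 198], [  0,   0, 101]])
print(batch["attention_mask"])  # tensor([[True, True, True], [False, False, True]])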
@@ -318,18 +313,38 @@ def decode_one_batch(
         2,
     )
-    # supervisions = batch["supervisions"]
-    # feature_len = supervisions["num_frames"]
-    # feature_len = feature_len.to(device, dtype=dtype)
-    messages = [
-        [
-            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
-            {"role": "assistant", "content": ""},
-        ]
-    ] * len(feature)
+    chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
+    # messages = [
+    #     [
+    #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+    #         {"role": "assistant", "content": ""},
+    #     ]
+    # ] * len(feature)
+    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
+    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    messages = []
+    for i, total_round in enumerate(chat_rounds):
+        message = []
+        if total_round > 1:
+            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = [item for item in history_question_answer if item]
+            for j in range(total_round - 1):
+                # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
+                question_answer = history_question_answer[j].split('ASSISTANT:')
+                message += [
+                    {"role": "user", "content": question_answer[0].strip()},
+                    {"role": "assistant", "content": question_answer[1].strip()}
+                ]
+        message += [
+            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+            # {"role": "user", "content": f"{last_questions[i]}"},
+            {"role": "assistant", "content": ""}
+        ]
+        messages.append(message)
-    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
+    input_ids, attention_mask = preprocess(messages, tokenizer)
     generated_ids = model.decode(
         feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
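The loop above assumes a fixed serialization of the dialogue inside `cut.custom["question"]`: earlier rounds are concatenated as `USER: … ASSISTANT: …` and the current turn is introduced by `<USER>:`, while the audio itself enters the prompt as `DEFAULT_SPEECH_TOKEN`. A self-contained sketch of the same splitting logic on a made-up two-round history (the sample string follows the format shown in the inline comment; `<speech>` stands in for DEFAULT_SPEECH_TOKEN):

# A two-round history in the serialization the code expects: plain "USER:" for
# past rounds, "<USER>:" introducing the current question.
question_with_history = (
    "USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长。 "
    "<USER>: 给我列举一些新闻头条。"
)
total_round = 2

history_context = question_with_history.rsplit("<USER>:", 1)[0].strip()
last_question = question_with_history.split("<USER>: ")[-1].strip()
# last_question mirrors decode.py's last_questions; the prompt below does not
# use it because the audio token replaces the written question.

message = []
if total_round > 1:
    turns = [t for t in history_context.split("USER:") if t]
    for turn in turns[: total_round - 1]:
        question, answer = turn.split("ASSISTANT:")
        message += [
            {"role": "user", "content": question.strip()},
            {"role": "assistant", "content": answer.strip()},
        ]
# The current turn carries only the speech placeholder; the assistant slot is
# left empty for generation.
message += [{"role": "user", "content": "<speech>"}, {"role": "assistant", "content": ""}]
print(message)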
@@ -422,7 +437,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
+                # ref_text = normalize_text_alimeeting(ref_text)
                 ref_words = ref_text.split()
                 print(f"ref: {ref_text}")
                 print(f"hyp: {''.join(hyp_words)}")
@@ -449,7 +464,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -459,7 +474,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         # we compute CER for aishell dataset.
         results_char = []
@@ -475,7 +490,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -499,8 +514,10 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
+    params.log_dir.mkdir(parents=True, exist_ok=True)
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
     )
     logging.info("Decoding started")
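Taken together, the `save_results` and `main` changes move every per-run artifact out of the experiment root into a per-method subdirectory. A short sketch of the resulting layout (the method, test-set, and key names are illustrative; the suffix follows from the script's `--epoch 999 --avg 1`):

from pathlib import Path

exp_dir = Path("slam_omni/exp_speech2text")               # as in the script above
method, suffix = "pure_text_sampling", "epoch-999-avg-1"  # illustrative values
test_set_name, key = "test", "greedy"                     # placeholders

log_dir = exp_dir / f"log-{method}"  # params.log_dir
log_dir.mkdir(parents=True, exist_ok=True)

print(log_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt")
print(log_dir / f"errs-{test_set_name}-{key}-{suffix}.txt")
print(log_dir / f"cer-summary-{test_set_name}-{suffix}.txt")
print(log_dir / f"log-decode-{suffix}")  # where setup_logger writes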

slam_omni/model.py

@@ -241,27 +241,41 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
-            attention_mask,
             _,
-            position_ids,
+            _,
+            _,
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
-            do_sample=kwargs.get("do_sample", False),
+            do_sample=kwargs.get("do_sample", True),
             min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
+            top_p=kwargs.get("top_p", 0.5),
+            top_k=kwargs.get("top_k", 20),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
+            temperature=kwargs.get("temperature", 0.7),
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id,
         )
+        # generated_ids = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 1),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     bos_token_id=self.llm.config.bos_token_id,
+        #     eos_token_id=self.llm.config.eos_token_id,
+        #     pad_token_id=self.llm.config.pad_token_id,
+        # )
         return generated_ids
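The rewritten defaults switch `SPEECH_LLM.decode` from effectively greedy decoding (`do_sample=False`, `top_p=1.0`, `temperature=1.0`) to fairly constrained nucleus sampling: only the top 50% of probability mass and at most 20 candidates per step, temperature 0.7, a 1.1 repetition penalty, and a budget of 1024 new tokens. `length_penalty` drops out because it only affects beam search, and `num_beams` stays at 1. The same settings collected in a Hugging Face `GenerationConfig` (a sketch; the commented usage is assumed, not taken from the repo):

from transformers import GenerationConfig

# Post-commit sampling defaults from SPEECH_LLM.decode, in one reusable object.
gen_config = GenerationConfig(
    max_new_tokens=1024,     # was 200
    num_beams=1,
    do_sample=True,          # was False (greedy)
    min_length=1,
    top_p=0.5,               # was 1.0: keep only half the probability mass
    top_k=20,                # new: cap candidates per step
    repetition_penalty=1.1,  # was 1.0: mildly discourage loops
    temperature=0.7,         # was 1.0: sharpen the distribution
)

# Hypothetical usage, with model and inputs_embeds as in the method above:
# generated_ids = self.llm.generate(
#     inputs_embeds=inputs_embeds, generation_config=gen_config
# )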