refine decoding method

Commit 0c02da82ac (parent 3ad075af60) in https://github.com/k2-fsa/icefall.git
@@ -52,12 +52,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
     --max-duration 80 \
-    --exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
+    --exp-dir slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/qwen \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
+    --method pure_text_sampling \
     --use-lora True # --on-the-fly-feats True
 
 fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 2: "
+  python3 ./slam_omni/decode.py \
+    --max-duration 80 \
+    --exp-dir slam_omni/exp_speech2text \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --epoch 999 --avg 1 \
+    --manifest-dir data/fbank \
+    --use-flash-attn True \
+    --method pure_text_sampling_original_0.5B \
+    --use-lora False # --on-the-fly-feats True
+
+fi
slam_omni/decode.py

@@ -52,7 +52,6 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import k2
 import torch
 import torch.nn as nn
 import transformers
@@ -60,13 +59,12 @@ import whisper
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-# from data_module import MultiDataset
 
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -76,7 +74,6 @@ from icefall.utils import (
     write_error_stats,
 )
 
-
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
@@ -133,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--llm-path-or-name",
         type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        default="",
         help="Path or name of the large language model.",
     )
 
@@ -264,7 +261,6 @@ def decode_one_batch(
     def preprocess(
         messages,
         tokenizer: transformers.PreTrainedTokenizer,
-        max_len: int = 128,
     ) -> Dict:
         """Preprocesses the data for supervised fine-tuning."""
         texts = []
@@ -277,8 +273,7 @@ def decode_one_batch(
                     add_generation_prompt=False,
                     chat_template=TEMPLATE,
                     padding="longest",
-                    max_length=max_len,
-                    truncation=True,
+                    truncation=False,
                 )
             )
         max_len_texts = max([len(text) for text in texts])
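This change removes the fixed 128-token cap from preprocess: prompts now carry full multi-turn history, so the batch is tokenized without truncation and padded out to its longest item instead. A minimal sketch of that pad-to-longest idea, assuming a Qwen2.5-style chat tokenizer; the model name and sample texts here are illustrative, not taken from the recipe:

# Sketch only: tokenize a batch of conversations without truncation, then
# right-pad every sequence to the longest one in the batch.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
conversations = [
    [{"role": "user", "content": "List some news headlines."}],
    [
        {"role": "user", "content": "Write a poem about summer."},
        {"role": "assistant", "content": "Sunny days, everything grows."},
        {"role": "user", "content": "Now translate it."},
    ],
]
ids = [tokenizer.apply_chat_template(c, tokenize=True) for c in conversations]
longest = max(len(x) for x in ids)
input_ids = [x + [tokenizer.pad_token_id] * (longest - len(x)) for x in ids]
attention_mask = [[1] * len(x) + [0] * (longest - len(x)) for x in ids]

Truncating at a fixed max_len would silently drop the earliest turns of exactly the histories this commit starts feeding into the prompt.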
@@ -318,18 +313,38 @@ def decode_one_batch(
         2,
     )
 
-    # supervisions = batch["supervisions"]
-    # feature_len = supervisions["num_frames"]
-    # feature_len = feature_len.to(device, dtype=dtype)
+    chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
 
-    messages = [
-        [
-            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
-            {"role": "assistant", "content": ""},
-        ]
-    ] * len(feature)
+    # messages = [
+    #     [
+    #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+    #         {"role": "assistant", "content": ""},
+    #     ]
+    # ] * len(feature)
+    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
+    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    messages = []
+    for i, total_round in enumerate(chat_rounds):
+        message = []
+        if total_round > 1:
+            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = [item for item in history_question_answer if item]
+            for j in range(total_round - 1):
+                # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎，万物生长，阳光明媚，享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
+                question_answer = history_question_answer[j].split('ASSISTANT:')
+                message += [
+                    {"role": "user", "content": question_answer[0].strip()},
+                    {"role": "assistant", "content": question_answer[1].strip()}
+                ]
+        message += [
+            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+            # {"role": "user", "content": f"{last_questions[i]}"},
+            {"role": "assistant", "content": ""}
+        ]
+        messages.append(message)
 
-    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
+    input_ids, attention_mask = preprocess(messages, tokenizer)
 
     generated_ids = model.decode(
         feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
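The new loop above rebuilds a per-sample chat history from the flattened question string. A standalone sketch of that parsing, assuming each cut's custom "question" field stores earlier turns as "USER: <q> ASSISTANT: <a> ..."; the sample text is illustrative (the recipe's data is Chinese) and error handling is omitted:

# Sketch only: split a flattened dialogue string into chat-template messages.
history = (
    "USER: Write a poem about summer. ASSISTANT: Sunny days, everything grows. "
    "USER: List some news headlines. ASSISTANT: The news never stops these days."
)

message = []
for turn in filter(None, history.split("USER:")):
    question, answer = turn.split("ASSISTANT:")
    message += [
        {"role": "user", "content": question.strip()},
        {"role": "assistant", "content": answer.strip()},
    ]

# As in the diff, the current audio turn is then appended as a
# DEFAULT_SPEECH_TOKEN user message with an empty assistant slot.
print(message)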
@@ -422,7 +437,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
+                # ref_text = normalize_text_alimeeting(ref_text)
                 ref_words = ref_text.split()
                 print(f"ref: {ref_text}")
                 print(f"hyp: {''.join(hyp_words)}")
@@ -449,7 +464,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -459,7 +474,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         # we compute CER for aishell dataset.
         results_char = []
@@ -475,7 +490,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -499,8 +514,10 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
+    params.log_dir.mkdir(parents=True, exist_ok=True)
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
     )
 
     logging.info("Decoding started")
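Together with the params.log_dir changes in save_results above, every per-run artifact (recogs, errs, cer-summary, decode log) now lands in a method-specific directory under the experiment dir. A small sketch of the resulting layout; the test-set name and decoding key below are hypothetical, only the path pattern comes from the diff:

# Sketch only: where one decoding run's files end up after this commit.
from pathlib import Path

exp_dir = Path("slam_omni/exp_speech2text")   # from the script above
method = "pure_text_sampling"
suffix = "epoch-999-avg-1"                    # params.suffix

log_dir = exp_dir / f"log-{method}"           # params.log_dir
log_dir.mkdir(parents=True, exist_ok=True)

recog_path = log_dir / f"recogs-test-greedy-{suffix}.txt"
errs_filename = log_dir / f"errs-test-greedy-{suffix}.txt"
errs_info = log_dir / f"cer-summary-test-{suffix}.txt"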
slam_omni/model.py

@@ -241,27 +241,41 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
             attention_mask,
             _,
             position_ids,
             _,
             _,
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
-            do_sample=kwargs.get("do_sample", False),
+            do_sample=kwargs.get("do_sample", True),
             min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
+            top_p=kwargs.get("top_p", 0.5),
+            top_k=kwargs.get("top_k", 20),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
+            temperature=kwargs.get("temperature", 0.7),
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id,
         )
 
+        # generated_ids = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 1),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     bos_token_id=self.llm.config.bos_token_id,
+        #     eos_token_id=self.llm.config.eos_token_id,
+        #     pad_token_id=self.llm.config.pad_token_id,
+        # )
         return generated_ids
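SPEECH_LLM.decode now defaults to sampling instead of greedy search (the old greedy defaults are kept above as a commented block). A minimal sketch, not the recipe's code, expressing the same defaults as a transformers GenerationConfig; the variable names are mine:

# Sketch only: the new sampling defaults, collected in one place.
from transformers import GenerationConfig

sampling_config = GenerationConfig(
    max_new_tokens=1024,     # raised from 200 so long answers are not cut off
    do_sample=True,          # sample instead of greedy decoding
    num_beams=1,
    min_length=1,
    top_p=0.5,               # nucleus mass per step
    top_k=20,                # candidate pool per step
    repetition_penalty=1.1,  # mild penalty against looping outputs
    temperature=0.7,         # sharper than the previous 1.0
)

# generated_ids = hf_model.generate(
#     inputs_embeds=inputs_embeds, generation_config=sampling_config
# )

With num_beams=1 and do_sample=True this is plain ancestral sampling over a distribution restricted by both top_k and top_p.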