Mirror of https://github.com/k2-fsa/icefall.git

Commit 0c02da82ac ("refine decoding method"), parent 3ad075af60.
Shell run script:

@@ -52,12 +52,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
     --max-duration 80 \
-    --exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
+    --exp-dir slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/qwen \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
+    --method pure_text_sampling \
     --use-lora True # --on-the-fly-feats True
 fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 20: "
+  python3 ./slam_omni/decode.py \
+    --max-duration 80 \
+    --exp-dir slam_omni/exp_speech2text \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --epoch 999 --avg 1 \
+    --manifest-dir data/fbank \
+    --use-flash-attn True \
+    --method pure_text_sampling_original_0.5B \
+    --use-lora False # --on-the-fly-feats True
+fi
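After this change, stage 2 decodes the LoRA fine-tuned checkpoint in slam_omni/exp_speech2text (--use-lora True, --method pure_text_sampling), while the new stage 20 repeats the run against the unmodified Qwen2.5-0.5B-Instruct weights (--use-lora False, --method pure_text_sampling_original_0.5B), presumably as a no-fine-tuning baseline. The --method string also names the directory that collects logs and results (see the params.log_dir change in decode.py below).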
slam_omni/decode.py:

@@ -52,7 +52,6 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import k2
 import torch
 import torch.nn as nn
 import transformers
@@ -60,13 +59,12 @@ import whisper
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-# from data_module import MultiDataset
+
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -76,7 +74,6 @@ from icefall.utils import (
     write_error_stats,
 )
 
-
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
@@ -133,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--llm-path-or-name",
         type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        default="",
         help="Path or name of the large language model.",
     )
 
@@ -264,7 +261,6 @@ def decode_one_batch(
 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
-    max_len: int = 128,
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
@@ -277,8 +273,7 @@ def decode_one_batch(
                 add_generation_prompt=False,
                 chat_template=TEMPLATE,
                 padding="longest",
-                max_length=max_len,
-                truncation=True,
+                truncation=False,
             )
         )
     max_len_texts = max([len(text) for text in texts])
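A minimal standalone sketch of what this change buys (not repo code; the model name and sample messages are illustrative, and the tokenizer's built-in template stands in for TEMPLATE): with truncation=False and no max_length, the rendered multi-turn prompt is tokenized in full, whereas the old max_length=128 could silently cut the earliest dialogue rounds.

# Illustrative sketch, assuming a Qwen2-style chat tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "user", "content": "Write a poem about summer."},
    {"role": "assistant", "content": "Blazing sun, long bright days."},
    {"role": "user", "content": "List some news headlines."},
    {"role": "assistant", "content": ""},
]
# truncation=False keeps the whole history in the prompt; with the old
# max_length=128, truncation=True, long histories lost their earliest turns.
ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    padding="longest",
    truncation=False,
)
print(len(ids))  # full prompt length, no 128-token cap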
@@ -318,18 +313,38 @@ def decode_one_batch(
         2,
     )
 
-    # supervisions = batch["supervisions"]
-    # feature_len = supervisions["num_frames"]
-    # feature_len = feature_len.to(device, dtype=dtype)
+    chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
 
-    messages = [
-        [
-            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
-            {"role": "assistant", "content": ""},
-        ]
-    ] * len(feature)
+    # messages = [
+    #     [
+    #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+    #         {"role": "assistant", "content": ""},
+    #     ]
+    # ] * len(feature)
+    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
+    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    messages = []
+    for i, total_round in enumerate(chat_rounds):
+        message = []
+        if total_round > 1:
+            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = [item for item in history_question_answer if item]
+            for j in range(total_round - 1):
+                # e.g. "USER: Write a poem about summer. ASSISTANT: Summer blazes, all things grow, the sun shines bright, enjoying the beautiful days of summer. USER: List some news headlines. ASSISTANT: The news of today's world never stops."
+                question_answer = history_question_answer[j].split('ASSISTANT:')
+                message += [
+                    {"role": "user", "content": question_answer[0].strip()},
+                    {"role": "assistant", "content": question_answer[1].strip()}
+                ]
+        message += [
+            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+            # {"role": "user", "content": f"{last_questions[i]}"},
+            {"role": "assistant", "content": ""}
+        ]
+        messages.append(message)
 
-    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
+    input_ids, attention_mask = preprocess(messages, tokenizer)
 
     generated_ids = model.decode(
         feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
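To make the parsing concrete, here is a standalone sketch of how one cut.custom["question"] string is split into chat messages. The sample string is invented (English stand-ins for the Chinese data) but follows the format in the comment above: earlier rounds are prefixed USER:/ASSISTANT:, and the current round's question follows the final "<USER>:" marker.

# Standalone sketch of the history parsing in decode_one_batch (sample data
# is hypothetical; the separators follow the diff above).
question = (
    "USER: Write a poem about summer. "
    "ASSISTANT: Blazing sun, all things grow. "
    "USER: List some news headlines. "
    "ASSISTANT: The news never stops. "
    "<USER>: What should I cook tonight?"
)
history_context = question.rsplit("<USER>:", 1)[0].strip()
last_question = question.split("<USER>: ")[-1].strip()

message = []
for chunk in history_context.split("USER:"):
    if not chunk:  # drop the empty piece before the first "USER:"
        continue
    q, a = chunk.split("ASSISTANT:")
    message += [
        {"role": "user", "content": q.strip()},
        {"role": "assistant", "content": a.strip()},
    ]

print(message)        # two reconstructed history rounds
print(last_question)  # 'What should I cook tonight?'

Note that the final user turn sends only DEFAULT_SPEECH_TOKEN, so the model answers from the audio embedding rather than the transcribed question; last_questions is used only by the commented-out text-only variant.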
@@ -422,7 +437,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
+                # ref_text = normalize_text_alimeeting(ref_text)
                 ref_words = ref_text.split()
                 print(f"ref: {ref_text}")
                 print(f"hyp: {''.join(hyp_words)}")
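Disabling normalize_text_alimeeting leaves references unmodified for scoring, presumably because the references here are conversational answers from the chat data rather than AliMeeting ASR transcripts.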
@@ -449,7 +464,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -459,7 +474,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         # we compute CER for aishell dataset.
         results_char = []
@@ -475,7 +490,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -499,8 +514,10 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
+    params.log_dir.mkdir(parents=True, exist_ok=True)
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
     )
 
     logging.info("Decoding started")
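Together with the save_results changes above, every artifact of a decoding run (decode log, recogs-*, errs-*, cer-summary-*) now lands in a per-method directory such as slam_omni/exp_speech2text/log-pure_text_sampling/, so runs with different --method values no longer overwrite each other inside exp_dir; the -beam{params.beam_size} component is dropped from the logger path to match.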
slam_omni/model.py:

@@ -241,27 +241,41 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
-            attention_mask,
-            _,
-            position_ids,
+            _,
+            _,
+            _,
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
-            do_sample=kwargs.get("do_sample", False),
+            do_sample=kwargs.get("do_sample", True),
             min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
+            top_p=kwargs.get("top_p", 0.5),
+            top_k=kwargs.get("top_k", 20),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
+            temperature=kwargs.get("temperature", 0.7),
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id,
         )
 
+        # generated_ids = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 1),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     bos_token_id=self.llm.config.bos_token_id,
+        #     eos_token_id=self.llm.config.eos_token_id,
+        #     pad_token_id=self.llm.config.pad_token_id,
+        # )
         return generated_ids
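Generation thus switches from greedy search (do_sample=False) to sampling with top_p=0.5, top_k=20, temperature=0.7, repetition_penalty=1.1, and a larger 1024-token budget; length_penalty is dropped, and the merged attention_mask/position_ids are now discarded so generate() runs from inputs_embeds alone. Every option is read via kwargs.get, so a caller can still force the old deterministic behavior. A hypothetical override at the call site, assuming decode() forwards **kwargs as the kwargs.get calls suggest:

# Hypothetical call in decode.py; only the keyword names used by the
# kwargs.get lookups above are assumed.
generated_ids = model.decode(
    feature,
    input_ids.to(device, dtype=torch.long),
    attention_mask.to(device),
    do_sample=False,     # back to greedy decoding
    max_new_tokens=200,  # the previous token budget
)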