fix repeat bos and pad id

commit 559f9e2def
parent 80677a55f8
@@ -68,7 +68,6 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
 )
-
 from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -306,8 +305,7 @@ def get_params() -> AttributeDict:
 
 
 def extract_text_and_speech_token(
-    batch: dict,
-    enable_speech_output: bool
+    batch: dict, enable_speech_output: bool
 ) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
     """
     Extracts messages and speech tokens from a batch based on the dataset format.
@@ -342,28 +340,34 @@ def extract_text_and_speech_token(
         # The 'prompt_template' argument to the function seems unused if we determine it here.
         # For now, I will proceed assuming the internal logic dictates the template.
         # If the function argument `prompt_template` was meant to be the default, this logic would need adjustment.
-        current_prompt_template = "speech_qa"  # Default value for prompt_template for the current item
+        current_prompt_template = (
+            "speech_qa"  # Default value for prompt_template for the current item
+        )
         target = answers[i]
         message_list_item = []
 
         custom_data = batch["supervisions"]["cut"][i].custom
 
-        if 'round' in custom_data:
+        if "round" in custom_data:
             # slam_omni format dataset
             # For 'round' type, the current interaction's user prompt will use current_prompt_template ("speech_qa")
             current_question_with_history = custom_data["question"]
             total_round = custom_data["round"]
-            history_context = current_question_with_history.rsplit("<USER>:", 1)[0].strip()
+            history_context = current_question_with_history.rsplit("<USER>:", 1)[
+                0
+            ].strip()
             if total_round > 1:
                 history_question_answer = history_context.split("USER:")
-                history_question_answer = [item for item in history_question_answer if item]
+                history_question_answer = [
+                    item for item in history_question_answer if item
+                ]
                 for j in range(total_round - 1):
                     question_answer = history_question_answer[j].split("ASSISTANT:")
                     message_list_item += [
                         {"role": "user", "content": question_answer[0].strip()},
                         {"role": "assistant", "content": question_answer[1].strip()},
                     ]
-        elif 'continuation' in custom_data:
+        elif "continuation" in custom_data:
             # see https://huggingface.co/datasets/fixie-ai/librispeech_asr
             ASR_PROBABILITY = 0.3
             if random.random() < ASR_PROBABILITY:
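The slam_omni branch above rebuilds the chat history from a single question string by splitting on the `<USER>:`, `USER:`, and `ASSISTANT:` markers. A minimal standalone sketch of that parsing, using a made-up two-round example (only the marker conventions come from the code above; the sample text and variable values are hypothetical):

# Hypothetical multi-round question string in the slam_omni format.
question = (
    "USER: What is the capital of France? ASSISTANT: Paris. "
    "<USER>: And of Germany?"
)
total_round = 2

message_list_item = []
# Everything before the last "<USER>:" is prior-round history.
history_context = question.rsplit("<USER>:", 1)[0].strip()
if total_round > 1:
    history_question_answer = [x for x in history_context.split("USER:") if x]
    for j in range(total_round - 1):
        q, a = history_question_answer[j].split("ASSISTANT:")
        message_list_item += [
            {"role": "user", "content": q.strip()},
            {"role": "assistant", "content": a.strip()},
        ]
print(message_list_item)
# [{'role': 'user', 'content': 'What is the capital of France?'},
#  {'role': 'assistant', 'content': 'Paris.'}]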
@@ -382,6 +386,7 @@ def extract_text_and_speech_token(
 
     return messages, speech_tokens
 
+
 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -432,12 +437,11 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids
 
+
 def process_batch_text_continuation(batch: dict):
     messages = []
     transcripts = batch["supervisions"]["text"]
-    continuations = [
-        cut.custom["continuation"] for cut in batch["supervisions"]["cut"]
-    ]
+    continuations = [cut.custom["continuation"] for cut in batch["supervisions"]["cut"]]
     for i in range(len(transcripts)):
         message = [
             {
@@ -449,6 +453,7 @@ def process_batch_text_continuation(batch: dict):
         messages.append(message)
     return messages
 
+
 def preprocess_teacher(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -827,7 +832,6 @@ def get_model(params):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-    else:
     if params.use_lora:
         lora_config = LoraConfig(
             r=64,
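Dropping the `else:` appears to make the freeze step and the LoRA branch no longer mutually exclusive: the base LLM can be frozen and LoRA adapters attached on top in the same run. A rough sketch of that pattern, assuming the `peft` library and a hypothetical Qwen2 checkpoint name (the LoraConfig fields besides r=64 are illustrative, not taken from this hunk):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint name, for illustration only.
llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

# Freeze the base LLM parameters ...
for name, param in llm.named_parameters():
    param.requires_grad = False

# ... then attach trainable LoRA adapters on top of the frozen weights.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # only the adapter weights remain trainable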
@@ -848,6 +852,9 @@ def get_model(params):
         llm.print_trainable_parameters()
 
     llm.config.pad_token_id = tokenizer.pad_token_id
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
     )
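These three added lines appear to be the core of the "fix repeat bos and pad id" commit: the pad, bos, and eos ids in the LLM config are pinned to Qwen's `<|endoftext|>`, `<|im_start|>`, and `<|im_end|>` tokens, so generation can terminate on `<|im_end|>` instead of repeating bos or pad tokens. A small sketch of the lookup, assuming a hypothetical Qwen2 instruct checkpoint:

from transformers import AutoTokenizer

# Hypothetical checkpoint name, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

# Resolve the special-token ids that the hunk assigns to llm.config.
pad_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
bos_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
print(pad_id, bos_id, eos_id)  # three distinct, valid ids

# With pad distinct from eos and an explicit eos set, greedy/sampled decoding
# stops at <|im_end|> rather than padding or re-emitting the bos token.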
@@ -884,7 +891,9 @@ def get_model(params):
     elif params.speech_tokenizer_type == "cosyvoice1":
         codec_vocab_size = 4096 + 4
     else:
-        raise ValueError(f"Unknown speech tokenizer type: {params.speech_tokenizer_type}")
+        raise ValueError(
+            f"Unknown speech tokenizer type: {params.speech_tokenizer_type}"
+        )
 
     config = Qwen2Config(
         vocab_size=codec_vocab_size,
@@ -921,10 +930,14 @@ def get_model(params):
     if params.pretrained_model_path or params.last_stage_model_path:
         if params.pretrained_model_path is None:
             checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
-            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            missing_keys, unexpected_keys = model.load_state_dict(
+                checkpoint, strict=False
+            )
         else:
             checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            missing_keys, unexpected_keys = model.load_state_dict(
+                checkpoint, strict=False
+            )
             # set params.batch_idx_train according to the checkpoint name
             if "checkpoint-" in params.pretrained_model_path:
                 params.batch_idx_train = int(
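The load_state_dict calls above are only reformatted, but strict=False is what lets the model start from a checkpoint that covers only part of its parameters. A toy sketch of what the call returns, with illustrative module names rather than the real model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Pretend checkpoint: covers only the first layer and carries one extra key.
checkpoint = {
    "0.weight": torch.zeros(4, 4),
    "0.bias": torch.zeros(4),
    "extra": torch.zeros(1),
}

# strict=False loads whatever matches and reports the rest instead of raising.
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
print(missing_keys)     # ['1.weight', '1.bias']
print(unexpected_keys)  # ['extra']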
@@ -940,6 +953,7 @@ def get_model(params):
 
     return model, tokenizer
 
+
 def run(rank, world_size, args):
     """
     Args: