Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

Commit: 4a29430349
Parent: e52581e69b

    add loss type
@@ -239,7 +239,8 @@ fi
 if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
   log "stage 14: Client"
-  datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
+  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
+  datasets=(openbookqa commoneval)
   for dataset in ${datasets[@]}; do
     # sd-qa should use usa split
     if [ $dataset == "sd-qa" ]; then
@@ -250,17 +251,16 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
     echo $dataset $split_name
     python3 ./qwen_omni/client.py \
       --subset-name $dataset --split-name $split_name \
-      --output-dir test_result
+      --output-dir result_adapter_librispeech_kl_div_qa_template
   done
 fi


 if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   log "stage 15: Training Speech2Speech Model, adaptor only"
   exp_dir=./qwen_omni/exp_speech2text
   ngpu=2
   torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
-    --max-duration 600 \
+    --max-duration 700 \
     --enable-musan False \
     --audio-key audio --text-key continuation \
     --exp-dir $exp_dir \
@@ -271,7 +271,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
     --use-flash-attn True \
     --dataset-format speech_continuation \
-    --start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
+    --start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
     --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
 fi

@@ -321,3 +321,67 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
   torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
     $train_cmd_args
 fi
+
+
+if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+  log "stage 17: Server for adapter only speech continuation"
+  exp_dir=./qwen_omni/exp_speech2text
+  python3 ./qwen_omni/server.py \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
+    --use-flash-attn True \
+    --enable-speech-output False \
+    --use-lora False --prompt-template continuation
+fi
+
+if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
+  log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
+  exp_dir=./qwen_omni/exp_speech2text_kl
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 700 \
+    --enable-musan False \
+    --audio-key audio --text-key continuation \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --on-the-fly-feats True \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --dataset-format speech_continuation \
+    --loss-type kl_div --dataset librispeech \
+    --pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
+    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
+  log "stage 19: Server for kl loss"
+  exp_dir=./qwen_omni/exp_speech2text_kl
+  python3 ./qwen_omni/server.py \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
+    --use-flash-attn True \
+    --enable-speech-output False \
+    --use-lora False --prompt-template qa
+fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
+  exp_dir=./qwen_omni/exp_speech2text_kl_llm
+  pretrained_dir=./qwen_omni/exp_speech2text_kl
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 200 \
+    --enable-musan False \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
+fi
@@ -64,6 +64,8 @@ class SPEECH_LLM(nn.Module):
         encoder_projector: nn.Module,
         codec_lm: nn.Module = None,
         codec_lm_padding_side: str = "left",
+        teacher_llm: nn.Module = None,
+        kl_temperature: float = 2.0,
     ):
         super().__init__()
         self.encoder = encoder
@@ -92,6 +94,9 @@ class SPEECH_LLM(nn.Module):
             multidim_average="global",
             ignore_index=IGNORE_TOKEN_ID,
         )
+        if teacher_llm is not None:
+            self.teacher_llm = teacher_llm
+            self.kl_temperature = kl_temperature

     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@@ -256,6 +261,67 @@ class SPEECH_LLM(nn.Module):
         )
         return model_outputs.loss, acc

+    def forward_kl_div(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        labels: torch.LongTensor = None,
+        teacher_input_ids: torch.LongTensor = None,
+        teacher_attention_mask: torch.Tensor = None,
+        teacher_labels: torch.LongTensor = None,
+    ):
+        encoder_outs = self.encoder(fbank)
+
+        speech_features = self.encoder_projector(encoder_outs)
+
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        (
+            inputs_embeds,
+            attention_mask,
+            labels,
+            _,
+        ) = self._merge_input_ids_with_speech_features(
+            speech_features, inputs_embeds, input_ids, attention_mask, labels
+        )
+
+        model_outputs = self.llm(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
+        )
+
+        teacher_outputs = self.teacher_llm(
+            input_ids=teacher_input_ids,
+            attention_mask=teacher_attention_mask,
+        )
+
+        kl_loss = torch.nn.functional.kl_div(
+            torch.nn.functional.log_softmax(
+                model_outputs.logits[labels != -100] / self.kl_temperature,
+                dim=-1,
+            ),
+            torch.nn.functional.softmax(
+                teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
+                dim=-1,
+            ),
+            reduction="batchmean",
+        )
+
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            teacher_preds = torch.argmax(teacher_outputs.logits, -1)
+            acc = compute_accuracy(
+                preds.detach()[:, :-1],
+                labels.detach()[:, 1:],
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+            acc_teacher = compute_accuracy(
+                teacher_preds.detach()[:, :-1],
+                teacher_labels.detach()[:, 1:],
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+        return kl_loss, acc, acc_teacher
+
     def forward_with_speech_output(
         self,
         fbank: torch.Tensor = None,
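For reference, below is a minimal, self-contained sketch of the distillation objective that forward_kl_div computes, using synthetic logits rather than the SPEECH_LLM class (the tensor shapes, vocabulary size, and label values are made up for illustration). torch.nn.functional.kl_div takes log-probabilities first and probabilities second, so the call measures KL(teacher || student) over the non-ignored label positions, with temperature-scaled logits and "batchmean" reduction:

import torch
import torch.nn.functional as F

IGNORE_TOKEN_ID = -100
kl_temperature = 2.0

# Synthetic student/teacher logits: (batch, seq_len, vocab). The two sequences
# may have different lengths, but the label masks must select the same number
# of supervised positions from each.
student_logits = torch.randn(2, 5, 8)
teacher_logits = torch.randn(2, 4, 8)
labels = torch.tensor([[-100, -100, 5, 6, 7],
                       [-100, -100, -100, 2, 3]])
teacher_labels = torch.tensor([[-100, 5, 6, 7],
                               [-100, -100, 2, 3]])

# Boolean-mask indexing flattens to (num_supervised_positions, vocab).
kl_loss = F.kl_div(
    F.log_softmax(student_logits[labels != IGNORE_TOKEN_ID] / kl_temperature, dim=-1),
    F.softmax(teacher_logits[teacher_labels != IGNORE_TOKEN_ID] / kl_temperature, dim=-1),
    reduction="batchmean",  # divides the summed KL by the number of selected positions
)
print(kl_loss)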
@@ -21,6 +21,12 @@ def get_args():
         default=None,
         help="Checkpoint name or path, default to %(default)r",
     )
+    parser.add_argument(
+        "--prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template",
+    )
     add_model_arguments(parser)
     args = parser.parse_args()
     return args
@@ -59,8 +65,23 @@ model, tokenizer = get_model(args)
 app = FastAPI()

 device = torch.device("cuda")
+if args.prompt_template is None:
+    template = f"{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "qa":
+    template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "continuation":
+    template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "asr":
+    template = (
+        f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
+    )
+elif args.prompt_template == "mt":
+    template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
+else:
+    raise ValueError(f"Invalid prompt template: {args.prompt_template}")
+print("Using template:", template)
 message = [
-    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+    {"role": "user", "content": template},
     {"role": "assistant", "content": ""},
 ]
 TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
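For reference, a small sketch of what the new --prompt-template flag changes: it only swaps the user-turn text that is rendered through the Jinja chat template shown above. The snippet assumes jinja2 is installed and uses "<speech>" as a stand-in value for DEFAULT_SPEECH_TOKEN (illustrative only; the real value comes from the model code):

from jinja2 import Template

DEFAULT_SPEECH_TOKEN = "<speech>"  # assumed placeholder value for this example

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
    "{% endfor %}"
)

template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"  # --prompt-template qa
messages = [
    {"role": "user", "content": template},
    {"role": "assistant", "content": ""},
]

# Prints:
# <|im_start|>user
# Answer the following question:
#
# <speech><|im_end|>
# <|im_start|>assistant
print(Template(CHAT_TEMPLATE).render(messages=messages))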
@@ -74,8 +74,8 @@ from transformers import (
 from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
-    get_rank,
     get_local_rank,
+    get_rank,
     get_world_size,
     setup_logger,
     str2bool,
@@ -234,6 +234,21 @@ def get_parser():
         default="slam_omni",
         help="The format of the dataset.",
     )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="multi_en",
+        help="The name of the dataset.",
+    )
+
+    parser.add_argument(
+        "--loss-type",
+        type=str,
+        default="ce",
+        help="The type of loss to use.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)

@@ -335,6 +350,22 @@ def process_batch_vocalnet(batch: dict):
     return messages, answer_cosyvoice_speech_token


+def process_batch_text_vocalnet(batch: dict):
+    pass
+    answers = batch["supervisions"]["text"]
+    answer_cosyvoice_speech_token = [
+        cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
+    ]
+    messages = []
+    for i in range(len(answers)):
+        message = [
+            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": answers[i]},
+        ]
+        messages.append(message)
+    return messages, answer_cosyvoice_speech_token
+
+
 def process_batch_speech_continuation(batch: dict):
     messages = []
     for i in range(len(batch["supervisions"]["text"])):
@@ -350,6 +381,131 @@ def process_batch_speech_continuation(batch: dict):
     return messages


+def process_batch_text_continuation(batch: dict):
+    messages = []
+    for i in range(len(batch["supervisions"]["text"])):
+        transcript = batch["supervisions"]["cut"][i].custom["text"]
+        message = [
+            {
+                "role": "user",
+                "content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
+        ]
+        messages.append(message)
+    return messages
+
+
+def preprocess(
+    messages,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                chat_template=TEMPLATE,
+                add_generation_prompt=False,
+                padding="longest", # FIX me change padding to longest
+                truncation=False,
+            )
+        )
+    if len(texts) != len(messages):
+        logging.warning(f"Remove too long text, {messages} ")
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_prompt = True
+    if mask_prompt:
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+        mask_indices = torch.where(input_ids == default_speech_token_id)
+        for i in range(mask_indices[0].size(0)):
+            row = mask_indices[0][i]
+            col = mask_indices[1][i]
+            # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+            # WAR: TODO FIXME check qwen3
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+    return input_ids, attention_mask, target_ids
+
+
+def preprocess_teacher(
+    messages,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                chat_template=TEMPLATE,
+                add_generation_prompt=False,
+                padding="longest", # FIX me change padding to longest
+                truncation=False,
+            )
+        )
+    if len(texts) != len(messages):
+        logging.warning(f"Remove too long text, {messages} ")
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_prompt = True
+    if mask_prompt:
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+        mask_indices = torch.where(input_ids == default_speech_token_id)
+        for i in range(mask_indices[0].size(0)):
+            row = mask_indices[0][i]
+            col = mask_indices[1][i]
+            # + 2 to skip: 'assistant', '\n'
+            # WAR: TODO FIXME check qwen3
+            # THIS IS THE ONLY DIFFERENCE FROM preprocess
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+            target_ids[row, col] = default_speech_token_id
+    # remove default_speech_token_id from target_ids and input_ids
+    batch_size = target_ids.size(0)
+
+    target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+    input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+    return input_ids, attention_mask, target_ids
+
+
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
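For reference, a toy sketch of the masking step in preprocess_teacher above, on hand-written token ids instead of real tokenizer output: every position up to the speech placeholder plus the six chat-template tokens after it is masked with IGNORE_TOKEN_ID, and the placeholder itself is then dropped from both input_ids and target_ids so the teacher receives a text-only sequence. The id 151646 is the placeholder id mentioned in the comments; the other ids are arbitrary:

import torch

IGNORE_TOKEN_ID = -100
SPEECH_ID = 151646  # speech placeholder id referenced in the comments above

input_ids = torch.tensor([[11, 12, SPEECH_ID, 21, 22, 23, 24, 25, 26, 31, 32]])
target_ids = input_ids.clone()

rows, cols = torch.where(input_ids == SPEECH_ID)
for row, col in zip(rows, cols):
    target_ids[row, : col + 6] = IGNORE_TOKEN_ID  # mask prompt + template tokens
    target_ids[row, col] = SPEECH_ID              # re-mark the placeholder so it can be filtered out

batch_size = input_ids.size(0)
target_ids = target_ids[target_ids != SPEECH_ID].view(batch_size, -1)
input_ids = input_ids[input_ids != SPEECH_ID].view(batch_size, -1)

print(input_ids)   # tensor([[11, 12, 21, 22, 23, 24, 25, 26, 31, 32]])
print(target_ids)  # tensor([[-100, -100, -100, -100, -100, -100, -100, 26, 31, 32]])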
@@ -374,72 +530,6 @@ def compute_loss(
     Returns:
       Return a tuple of two elements. The first element is the loss tensor.
     """
-    # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
-    # we simply drop the last few shortest samples, so that the retained total frames
-    # (after padding) would not exceed `allowed_max_frames`:
-    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
-    # We set allowed_excess_duration_ratio=0.1.
-
-    def preprocess(
-        messages,
-        tokenizer: transformers.PreTrainedTokenizer,
-    ) -> Dict:
-        """Preprocesses the data for supervised fine-tuning."""
-        texts = []
-        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
-        for i, msg in enumerate(messages):
-            texts.append(
-                tokenizer.apply_chat_template(
-                    msg,
-                    tokenize=True,
-                    chat_template=TEMPLATE,
-                    add_generation_prompt=False,
-                    padding="longest", # FIX me change padding to longest
-                    truncation=False,
-                )
-            )
-        if len(texts) != len(messages):
-            logging.warning(f"Remove too long text, {messages} ")
-        max_len_texts = max([len(text) for text in texts])
-        if tokenizer.padding_side == "right":
-            texts = [
-                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
-                for text in texts
-            ]
-        else:
-            texts = [
-                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
-                for text in texts
-            ]
-        input_ids = torch.tensor(texts, dtype=torch.int)
-
-        target_ids = input_ids.clone()
-        target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
-        # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
-        # first get the indices of the tokens
-        mask_prompt = True
-        if mask_prompt:
-            default_speech_token_id = tokenizer.convert_tokens_to_ids(
-                DEFAULT_SPEECH_TOKEN
-            )
-            mask_indices = torch.where(input_ids == default_speech_token_id)
-            for i in range(mask_indices[0].size(0)):
-                row = mask_indices[0][i]
-                col = mask_indices[1][i]
-                # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
-                # WAR: TODO FIXME check qwen3
-                target_ids[row, : col + 6] = IGNORE_TOKEN_ID
-
-        attention_mask = input_ids.ne(tokenizer.pad_token_id)
-
-        return input_ids, attention_mask, target_ids
-
-    # max_frames = params.max_duration * 1000 // params.frame_shift_ms
-    # allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    # batch = filter_uneven_sized_batch(batch, allowed_max_frames)
-
     device = next(model.parameters()).device
     feature = batch["inputs"]

@@ -452,8 +542,12 @@ def compute_loss(
         messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
     elif params.dataset_format == "vocalnet":
         messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
+        if params.loss_type == "kl_div":
+            messages_text = process_batch_text_vocalnet(batch)
     elif params.dataset_format == "speech_continuation":
         messages = process_batch_speech_continuation(batch)
+        if params.loss_type == "kl_div":
+            messages_text = process_batch_text_continuation(batch)
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")

@@ -464,12 +558,30 @@ def compute_loss(

     with torch.set_grad_enabled(is_training):
         if not params.enable_speech_output:
-            loss, acc = model(
-                fbank=feature,
-                input_ids=input_ids.to(device),
-                attention_mask=attention_mask.to(device),
-                labels=target_ids.to(device),
-            )
+            if params.loss_type == "ce":
+                loss, acc = model(
+                    fbank=feature,
+                    input_ids=input_ids.to(device),
+                    attention_mask=attention_mask.to(device),
+                    labels=target_ids.to(device),
+                )
+            elif params.loss_type == "kl_div":
+                (
+                    teacher_input_ids,
+                    teacher_attention_mask,
+                    teacher_target_ids,
+                ) = preprocess_teacher(messages_text, tokenizer)
+                loss, acc, acc_teacher = model.forward_kl_div(
+                    fbank=feature,
+                    input_ids=input_ids.to(device),
+                    attention_mask=attention_mask.to(device),
+                    labels=target_ids.to(device),
+                    teacher_input_ids=teacher_input_ids.to(device),
+                    teacher_attention_mask=teacher_attention_mask.to(device),
+                    teacher_labels=teacher_target_ids.to(device),
+                )
+            else:
+                raise ValueError(f"Unknown loss type: {params.loss_type}")
         else:
             (
                 text_loss,
@@ -498,6 +610,8 @@ def compute_loss(
     info["acc"] = (
         acc * info["frames"]
     ) # WAR: to avoid normalization by the number of frames
+    if params.loss_type == "kl_div":
+        info["acc_teacher"] = acc_teacher * info["frames"]
     if params.enable_speech_output:
         info["codec_acc"] = codec_acc * info["frames"]
         info["codec_topk_acc"] = codec_topk_acc * info["frames"]
@@ -820,6 +934,17 @@ def run(rank, world_size, args):
         codec_lm.config.mask_token_id = codec_vocab_size - 4
     else:
         codec_lm = None
+    if params.loss_type == "kl_div":
+        teacher_llm = AutoModelForCausalLM.from_pretrained(
+            params.llm_path_or_name,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
+        for name, param in teacher_llm.named_parameters():
+            param.requires_grad = False
+        teacher_llm.eval()
+    else:
+        teacher_llm = None

     model = SPEECH_LLM(
         speech_encoder,
@@ -827,6 +952,7 @@ def run(rank, world_size, args):
         encoder_projector,
         codec_lm,
         codec_lm_padding_side="left" if params.use_flash_attn else "right",
+        teacher_llm=teacher_llm,
     )

     if params.pretrained_model_path:
@@ -834,7 +960,9 @@ def run(rank, world_size, args):
         missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
         # set params.batch_idx_train according to the checkpoint name
         if "checkpoint-" in params.pretrained_model_path:
-            params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
+            params.batch_idx_train = int(
+                params.pretrained_model_path.split("-")[-1].split("/")[0]
+            )

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -893,9 +1021,14 @@ def run(rank, world_size, args):
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
     elif params.dataset_format == "speech_continuation":
-        train_cuts = data_module.train_cuts_ultravox()
-        # train_cuts = data_module.train_cuts_gigaspeech()
-        # train_cuts = data_module.train_cuts_librispeech()
+        if params.dataset == "multi_en":
+            train_cuts = data_module.train_cuts_ultravox()
+        elif params.dataset == "librispeech":
+            train_cuts = data_module.train_cuts_librispeech()
+        elif params.dataset == "gigaspeech":
+            train_cuts = data_module.train_cuts_gigaspeech()
+        else:
+            raise ValueError(f"Unknown dataset: {params.dataset}")
         valid_cuts = data_module.valid_cuts_ultravox()
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")