Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
add loss type
This commit is contained in:
parent e52581e69b, commit 4a29430349
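This commit adds a --loss-type option (ce or kl_div) and a --dataset option to the qwen_omni recipe, new run stages 17-20, a --prompt-template flag for server.py, and a forward_kl_div path in SPEECH_LLM: a frozen copy of the text LLM acts as a teacher on the transcript-based prompt, and the speech-conditioned student is trained to match the teacher's next-token distribution with a temperature-scaled KL divergence over the answer tokens. The sketch below illustrates that loss in isolation; it is a minimal example, not the recipe's API, and it assumes student_logits and teacher_logits have already been gathered at matching supervised positions.

import torch
import torch.nn.functional as F

def kl_distillation_loss(
    student_logits: torch.Tensor,  # (num_answer_tokens, vocab)
    teacher_logits: torch.Tensor,  # (num_answer_tokens, vocab)
    temperature: float = 2.0,
) -> torch.Tensor:
    # Student in log space, teacher in probability space, as F.kl_div expects.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # "batchmean" sums over the vocabulary and averages over token positions,
    # matching the reduction used in forward_kl_div below.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean")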
@@ -239,7 +239,8 @@ fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "stage 14: Client"
-  datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
+  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
+  datasets=(openbookqa commoneval)
  for dataset in ${datasets[@]}; do
    # sd-qa should use usa split
    if [ $dataset == "sd-qa" ]; then
@@ -250,17 +251,16 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
    echo $dataset $split_name
    python3 ./qwen_omni/client.py \
      --subset-name $dataset --split-name $split_name \
-      --output-dir test_result
+      --output-dir result_adapter_librispeech_kl_div_qa_template
  done
fi


if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
  log "stage 15: Training Speech2Speech Model, adaptor only"
  exp_dir=./qwen_omni/exp_speech2text
  ngpu=2
  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
-    --max-duration 600 \
+    --max-duration 700 \
    --enable-musan False \
    --audio-key audio --text-key continuation \
    --exp-dir $exp_dir \
@@ -271,7 +271,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
    --use-flash-attn True \
    --dataset-format speech_continuation \
-    --start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
+    --start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi

@@ -321,3 +321,67 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
    $train_cmd_args
fi
+
+
+if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+  log "stage 17: Server for adapter only speech continuation"
+  exp_dir=./qwen_omni/exp_speech2text
+  python3 ./qwen_omni/server.py \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
+    --use-flash-attn True \
+    --enable-speech-output False \
+    --use-lora False --prompt-template continuation
+fi
+
+if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
+  log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
+  exp_dir=./qwen_omni/exp_speech2text_kl
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 700 \
+    --enable-musan False \
+    --audio-key audio --text-key continuation \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --on-the-fly-feats True \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --dataset-format speech_continuation \
+    --loss-type kl_div --dataset librispeech \
+    --pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
+    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
+  log "stage 19: Server for kl loss"
+  exp_dir=./qwen_omni/exp_speech2text_kl
+  python3 ./qwen_omni/server.py \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
+    --use-flash-attn True \
+    --enable-speech-output False \
+    --use-lora False --prompt-template qa
+fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
+  exp_dir=./qwen_omni/exp_speech2text_kl_llm
+  pretrained_dir=./qwen_omni/exp_speech2text_kl
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 200 \
+    --enable-musan False \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
+fi

@@ -64,6 +64,8 @@ class SPEECH_LLM(nn.Module):
        encoder_projector: nn.Module,
        codec_lm: nn.Module = None,
        codec_lm_padding_side: str = "left",
+        teacher_llm: nn.Module = None,
+        kl_temperature: float = 2.0,
    ):
        super().__init__()
        self.encoder = encoder
@@ -92,6 +94,9 @@ class SPEECH_LLM(nn.Module):
            multidim_average="global",
            ignore_index=IGNORE_TOKEN_ID,
        )
+        if teacher_llm is not None:
+            self.teacher_llm = teacher_llm
+            self.kl_temperature = kl_temperature

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@@ -256,6 +261,67 @@ class SPEECH_LLM(nn.Module):
        )
        return model_outputs.loss, acc

+    def forward_kl_div(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        labels: torch.LongTensor = None,
+        teacher_input_ids: torch.LongTensor = None,
+        teacher_attention_mask: torch.Tensor = None,
+        teacher_labels: torch.LongTensor = None,
+    ):
+        encoder_outs = self.encoder(fbank)
+
+        speech_features = self.encoder_projector(encoder_outs)
+
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        (
+            inputs_embeds,
+            attention_mask,
+            labels,
+            _,
+        ) = self._merge_input_ids_with_speech_features(
+            speech_features, inputs_embeds, input_ids, attention_mask, labels
+        )
+
+        model_outputs = self.llm(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
+        )
+
+        teacher_outputs = self.teacher_llm(
+            input_ids=teacher_input_ids,
+            attention_mask=teacher_attention_mask,
+        )
+
+        kl_loss = torch.nn.functional.kl_div(
+            torch.nn.functional.log_softmax(
+                model_outputs.logits[labels != -100] / self.kl_temperature,
+                dim=-1,
+            ),
+            torch.nn.functional.softmax(
+                teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
+                dim=-1,
+            ),
+            reduction="batchmean",
+        )
+
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            teacher_preds = torch.argmax(teacher_outputs.logits, -1)
+            acc = compute_accuracy(
+                preds.detach()[:, :-1],
+                labels.detach()[:, 1:],
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+            acc_teacher = compute_accuracy(
+                teacher_preds.detach()[:, :-1],
+                teacher_labels.detach()[:, 1:],
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+        return kl_loss, acc, acc_teacher
+
    def forward_with_speech_output(
        self,
        fbank: torch.Tensor = None,
@@ -21,6 +21,12 @@ def get_args():
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
+    parser.add_argument(
+        "--prompt-template",
+        type=str,
+        default=None,
+        help="Prompt template",
+    )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args
@@ -59,8 +65,23 @@ model, tokenizer = get_model(args)
app = FastAPI()

device = torch.device("cuda")
+if args.prompt_template is None:
+    template = f"{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "qa":
+    template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "continuation":
+    template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
+elif args.prompt_template == "asr":
+    template = (
+        f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
+    )
+elif args.prompt_template == "mt":
+    template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
+else:
+    raise ValueError(f"Invalid prompt template: {args.prompt_template}")
+print("Using template:", template)
message = [
-    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+    {"role": "user", "content": template},
    {"role": "assistant", "content": ""},
]
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
@@ -74,8 +74,8 @@ from transformers import (
from utils import ( # filter_uneven_sized_batch,
    AttributeDict,
    MetricsTracker,
-    get_rank,
+    get_local_rank,
+    get_rank,
    get_world_size,
    setup_logger,
    str2bool,
@@ -234,6 +234,21 @@ def get_parser():
        default="slam_omni",
        help="The format of the dataset.",
    )

+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="multi_en",
+        help="The name of the dataset.",
+    )
+
+    parser.add_argument(
+        "--loss-type",
+        type=str,
+        default="ce",
+        help="The type of loss to use.",
+    )
+
    parser = deepspeed.add_config_arguments(parser)
    add_model_arguments(parser)

@@ -335,6 +350,22 @@ def process_batch_vocalnet(batch: dict):
    return messages, answer_cosyvoice_speech_token


+def process_batch_text_vocalnet(batch: dict):
+    pass
+    answers = batch["supervisions"]["text"]
+    answer_cosyvoice_speech_token = [
+        cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
+    ]
+    messages = []
+    for i in range(len(answers)):
+        message = [
+            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": answers[i]},
+        ]
+        messages.append(message)
+    return messages, answer_cosyvoice_speech_token
+
+
def process_batch_speech_continuation(batch: dict):
    messages = []
    for i in range(len(batch["supervisions"]["text"])):
@@ -350,6 +381,131 @@ def process_batch_speech_continuation(batch: dict):
    return messages


+def process_batch_text_continuation(batch: dict):
+    messages = []
+    for i in range(len(batch["supervisions"]["text"])):
+        transcript = batch["supervisions"]["cut"][i].custom["text"]
+        message = [
+            {
+                "role": "user",
+                "content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
+        ]
+        messages.append(message)
+    return messages
+
+
+def preprocess(
+    messages,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                chat_template=TEMPLATE,
+                add_generation_prompt=False,
+                padding="longest",  # FIX me change padding to longest
+                truncation=False,
+            )
+        )
+    if len(texts) != len(messages):
+        logging.warning(f"Remove too long text, {messages} ")
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_prompt = True
+    if mask_prompt:
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+        mask_indices = torch.where(input_ids == default_speech_token_id)
+        for i in range(mask_indices[0].size(0)):
+            row = mask_indices[0][i]
+            col = mask_indices[1][i]
+            # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+            # WAR: TODO FIXME check qwen3
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+    return input_ids, attention_mask, target_ids
+
+
+def preprocess_teacher(
+    messages,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                chat_template=TEMPLATE,
+                add_generation_prompt=False,
+                padding="longest",  # FIX me change padding to longest
+                truncation=False,
+            )
+        )
+    if len(texts) != len(messages):
+        logging.warning(f"Remove too long text, {messages} ")
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_prompt = True
+    if mask_prompt:
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+        mask_indices = torch.where(input_ids == default_speech_token_id)
+        for i in range(mask_indices[0].size(0)):
+            row = mask_indices[0][i]
+            col = mask_indices[1][i]
+            # + 2 to skip: 'assistant', '\n'
+            # WAR: TODO FIXME check qwen3
+            # THIS IS THE ONLY DIFFERENCE FROM preprocess
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+            target_ids[row, col] = default_speech_token_id
+    # remove default_speech_token_id from target_ids and input_ids
+    batch_size = target_ids.size(0)
+
+    target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+    input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+    return input_ids, attention_mask, target_ids
+
+
def compute_loss(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
@@ -374,72 +530,6 @@ def compute_loss(
    Returns:
      Return a tuple of two elements. The first element is the loss tensor.
    """
-    # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
-    # we simply drop the last few shortest samples, so that the retained total frames
-    # (after padding) would not exceed `allowed_max_frames`:
-    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
-    # We set allowed_excess_duration_ratio=0.1.
-
-    def preprocess(
-        messages,
-        tokenizer: transformers.PreTrainedTokenizer,
-    ) -> Dict:
-        """Preprocesses the data for supervised fine-tuning."""
-        texts = []
-        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
-        for i, msg in enumerate(messages):
-            texts.append(
-                tokenizer.apply_chat_template(
-                    msg,
-                    tokenize=True,
-                    chat_template=TEMPLATE,
-                    add_generation_prompt=False,
-                    padding="longest",  # FIX me change padding to longest
-                    truncation=False,
-                )
-            )
-        if len(texts) != len(messages):
-            logging.warning(f"Remove too long text, {messages} ")
-        max_len_texts = max([len(text) for text in texts])
-        if tokenizer.padding_side == "right":
-            texts = [
-                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
-                for text in texts
-            ]
-        else:
-            texts = [
-                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
-                for text in texts
-            ]
-        input_ids = torch.tensor(texts, dtype=torch.int)
-
-        target_ids = input_ids.clone()
-        target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
-        # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
-        # first get the indices of the tokens
-        mask_prompt = True
-        if mask_prompt:
-            default_speech_token_id = tokenizer.convert_tokens_to_ids(
-                DEFAULT_SPEECH_TOKEN
-            )
-            mask_indices = torch.where(input_ids == default_speech_token_id)
-            for i in range(mask_indices[0].size(0)):
-                row = mask_indices[0][i]
-                col = mask_indices[1][i]
-                # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
-                # WAR: TODO FIXME check qwen3
-                target_ids[row, : col + 6] = IGNORE_TOKEN_ID
-
-        attention_mask = input_ids.ne(tokenizer.pad_token_id)
-
-        return input_ids, attention_mask, target_ids
-
-    # max_frames = params.max_duration * 1000 // params.frame_shift_ms
-    # allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    # batch = filter_uneven_sized_batch(batch, allowed_max_frames)
-
    device = next(model.parameters()).device
    feature = batch["inputs"]

@@ -452,8 +542,12 @@ def compute_loss(
        messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
    elif params.dataset_format == "vocalnet":
        messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
+        if params.loss_type == "kl_div":
+            messages_text = process_batch_text_vocalnet(batch)
    elif params.dataset_format == "speech_continuation":
        messages = process_batch_speech_continuation(batch)
+        if params.loss_type == "kl_div":
+            messages_text = process_batch_text_continuation(batch)
    else:
        raise ValueError(f"Unknown dataset format: {params.dataset_format}")

@@ -464,12 +558,30 @@ def compute_loss(

    with torch.set_grad_enabled(is_training):
        if not params.enable_speech_output:
-            loss, acc = model(
-                fbank=feature,
-                input_ids=input_ids.to(device),
-                attention_mask=attention_mask.to(device),
-                labels=target_ids.to(device),
-            )
+            if params.loss_type == "ce":
+                loss, acc = model(
+                    fbank=feature,
+                    input_ids=input_ids.to(device),
+                    attention_mask=attention_mask.to(device),
+                    labels=target_ids.to(device),
+                )
+            elif params.loss_type == "kl_div":
+                (
+                    teacher_input_ids,
+                    teacher_attention_mask,
+                    teacher_target_ids,
+                ) = preprocess_teacher(messages_text, tokenizer)
+                loss, acc, acc_teacher = model.forward_kl_div(
+                    fbank=feature,
+                    input_ids=input_ids.to(device),
+                    attention_mask=attention_mask.to(device),
+                    labels=target_ids.to(device),
+                    teacher_input_ids=teacher_input_ids.to(device),
+                    teacher_attention_mask=teacher_attention_mask.to(device),
+                    teacher_labels=teacher_target_ids.to(device),
+                )
+            else:
+                raise ValueError(f"Unknown loss type: {params.loss_type}")
        else:
            (
                text_loss,
@@ -498,6 +610,8 @@ def compute_loss(
    info["acc"] = (
        acc * info["frames"]
    )  # WAR: to avoid normalization by the number of frames
+    if params.loss_type == "kl_div":
+        info["acc_teacher"] = acc_teacher * info["frames"]
    if params.enable_speech_output:
        info["codec_acc"] = codec_acc * info["frames"]
        info["codec_topk_acc"] = codec_topk_acc * info["frames"]
@@ -820,6 +934,17 @@ def run(rank, world_size, args):
        codec_lm.config.mask_token_id = codec_vocab_size - 4
    else:
        codec_lm = None
+    if params.loss_type == "kl_div":
+        teacher_llm = AutoModelForCausalLM.from_pretrained(
+            params.llm_path_or_name,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
+        for name, param in teacher_llm.named_parameters():
+            param.requires_grad = False
+        teacher_llm.eval()
+    else:
+        teacher_llm = None

    model = SPEECH_LLM(
        speech_encoder,
@@ -827,6 +952,7 @@ def run(rank, world_size, args):
        encoder_projector,
        codec_lm,
        codec_lm_padding_side="left" if params.use_flash_attn else "right",
+        teacher_llm=teacher_llm,
    )

    if params.pretrained_model_path:
@@ -834,7 +960,9 @@ def run(rank, world_size, args):
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
        # set params.batch_idx_train according to the checkpoint name
        if "checkpoint-" in params.pretrained_model_path:
-            params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
+            params.batch_idx_train = int(
+                params.pretrained_model_path.split("-")[-1].split("/")[0]
+            )

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@@ -893,9 +1021,14 @@ def run(rank, world_size, args):
        train_cuts = data_module.train_cuts_en_vocalnet()
        valid_cuts = data_module.valid_cuts_en_vocalnet()
    elif params.dataset_format == "speech_continuation":
-        train_cuts = data_module.train_cuts_ultravox()
-        # train_cuts = data_module.train_cuts_gigaspeech()
-        # train_cuts = data_module.train_cuts_librispeech()
+        if params.dataset == "multi_en":
+            train_cuts = data_module.train_cuts_ultravox()
+        elif params.dataset == "librispeech":
+            train_cuts = data_module.train_cuts_librispeech()
+        elif params.dataset == "gigaspeech":
+            train_cuts = data_module.train_cuts_gigaspeech()
+        else:
+            raise ValueError(f"Unknown dataset: {params.dataset}")
        valid_cuts = data_module.valid_cuts_ultravox()
    else:
        raise ValueError(f"Unknown dataset format: {params.dataset_format}")

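Note on the kl_div path: forward_kl_div gathers student logits at labels != -100 and teacher logits at teacher_labels != -100, so the two masks must select the same number of answer-token positions; preprocess and preprocess_teacher appear intended to keep those masks aligned. A hypothetical sanity check, not part of this commit, that one could run on a batch:

def check_kl_alignment(labels, teacher_labels, ignore_id=-100) -> None:
    # Both masks must pick out the same number of supervised positions,
    # otherwise the KL between the gathered student and teacher logits
    # cannot be formed element-wise.
    n_student = (labels != ignore_id).sum().item()
    n_teacher = (teacher_labels != ignore_id).sum().item()
    assert n_student == n_teacher, (
        f"student supervises {n_student} tokens, teacher supervises {n_teacher}"
    )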