Mirror of https://github.com/k2-fsa/icefall.git
add codec lm

commit bdb60f6ddc
parent 458d697acc
@@ -96,3 +96,22 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --use-lora True --unfreeze-llm True
fi


if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "stage 4: Train model with speech codec output (codec LM)"
  ngpu=2
  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --max-duration 40 \
    --enable-musan False \
    --exp-dir ./slam_omni/exp_speech2text \
    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --manifest-dir data/fbank \
    --deepspeed \
    --deepspeed_config ./slam_omni/ds_config_zero1.json \
    --use-flash-attn False \
    --use-lora True --unfreeze-llm False --enable-speech-output True
    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
    # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \

fi
@@ -58,11 +58,21 @@ class SPEECH_LLM(nn.Module):
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        codec_lm: nn.Module = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
        self.codec_lm = codec_lm
        if self.codec_lm:
            self.speech_token_projector = nn.Linear(
                self.llm.config.hidden_size, self.codec_lm.config.hidden_size
            )
            self.codec_lm_head = nn.Linear(
                self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
            )
        self.loss_fct = torch.nn.CrossEntropyLoss()

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
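For context on the new wiring above: when a codec_lm is supplied, the model adds a linear speech_token_projector that maps the text LLM's hidden size into the codec LM's hidden size, and a codec_lm_head that maps codec-LM hidden states to codec-token logits. Below is a minimal shape-check sketch; the sizes (896 for the LLM, 1024 and 8192 for the codec LM) are only illustrative, since the real values come from llm.config and codec_lm.config.

import torch
import torch.nn as nn

# Illustrative sizes only; in the diff they are read from the two models' configs.
llm_hidden, codec_hidden, codec_vocab = 896, 1024, 8192

speech_token_projector = nn.Linear(llm_hidden, codec_hidden)
codec_lm_head = nn.Linear(codec_hidden, codec_vocab)

text_hidden_states = torch.randn(2, 16, llm_hidden)     # (batch, text_len, llm_hidden)
projected = speech_token_projector(text_hidden_states)  # (2, 16, codec_hidden)
codec_hidden_states = torch.randn(2, 40, codec_hidden)  # stand-in for codec LM outputs
audio_logits = codec_lm_head(codec_hidden_states)       # (2, 40, codec_vocab)
assert audio_logits.shape == (2, 40, codec_vocab)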
@@ -225,8 +235,112 @@ class SPEECH_LLM(nn.Module):
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs, acc
        return model_outputs.loss, acc

    def forward_with_speech_output(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
        speech_codec_ids: torch.LongTensor = None,
    ):
        encoder_outs = self.encoder(fbank)

        speech_features = self.encoder_projector(encoder_outs)

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )

        # get the label start_index in inputs_embeds from labels
        text_label_start_index_list = []
        for i in range(labels.shape[0]):
            text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
            text_label_start_index_list.append(text_label_start_index)

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
        )
        text_loss = model_outputs.loss

        # prepare codec lm inputs
        audio_codes_lens = torch.tensor(
            [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
        )
        # print(audio_codes_lens, "audio_codes_lens")
        max_len_speech_codec = max(audio_codes_lens)
        delay_step = 2
        audio_codes = torch.full(
            (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
            self.codec_lm.config.pad_token_id,
            dtype=torch.int64,
            device=input_ids.device
        )
        audio_labels = audio_codes.clone()

        for i, speech_codec in enumerate(speech_codec_ids):
            text_label_start_index = text_label_start_index_list[i]
            speech_codec = torch.tensor(
                speech_codec, dtype=torch.int64, device=input_ids.device
            )
            # print(inputs_embeds[i, text_label_start_index], "2333 test")
            audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
            audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
            audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
            audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id

        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
        audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

        # input_ids: seq_len T1, audio_codec seq_len T2
        text_last_hidden_outputs = model_outputs.hidden_states[-1]
        text_input_embeds = inputs_embeds + text_last_hidden_outputs
        text_input_embeds = self.speech_token_projector(text_input_embeds)

        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds

        speech_outputs = self.codec_lm(
            attention_mask=audio_attention_mask,
            inputs_embeds=audio_embeddings,
            return_dict=True,
            output_hidden_states=True,
        )
        last_hidden_state = speech_outputs.hidden_states[-1].clone()

        audio_logits = self.codec_lm_head(last_hidden_state)  # shape, B, T, vocab_size
        audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
        audio_labels = audio_labels.contiguous().view(-1)
        audio_labels = audio_labels.masked_fill(
            audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
        )
        codec_loss = self.loss_fct(audio_logits, audio_labels)
        audio_preds = torch.argmax(audio_logits, -1)


        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
            audio_acc = compute_accuracy(
                audio_preds.detach(),
                audio_labels.detach(),
                ignore_label=IGNORE_TOKEN_ID,
            )


        return text_loss, acc, codec_loss, audio_acc

    def decode(
        self,
        fbank: torch.Tensor = None,
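The trickiest part of forward_with_speech_output above is how audio_codes (codec-LM inputs) and audio_labels (targets) are laid out relative to the text: codec tokens start delay_step + 1 positions after the first non-ignored text label, everything before that is filled with the codec BOS id as a filler/mask token, and the label sequence is the input sequence shifted left by one with an EOS appended. Here is a small worked sketch with toy values; the token ids (101..104) and the label layout are made up, while the index arithmetic and the id convention (pad = vocab-1, eos = vocab-2, bos = vocab-3 with vocab 8192) mirror the diff.

import torch

IGNORE_TOKEN_ID = -100
bos_id, eos_id, pad_id = 8189, 8190, 8191   # vocab-3, vocab-2, vocab-1, as in the diff
delay_step = 2

# One toy example: text labels become valid at position 3, and there are 4 codec tokens.
labels = torch.tensor([[-100, -100, -100, 11, 12, 13, 14, 15]])
speech_codec = torch.tensor([101, 102, 103, 104])

text_start = torch.where(labels[0] != IGNORE_TOKEN_ID)[0][0].item()  # -> 3
T = labels.shape[1] + len(speech_codec) + 1                          # text_len + codec_len + 1

audio_codes = torch.full((T,), pad_id)
audio_labels = torch.full((T,), pad_id)

audio_codes[: text_start + delay_step + 1] = bos_id                  # filler up to position 5
audio_codes[text_start + delay_step + 1 : text_start + delay_step + 1 + 4] = speech_codec
audio_labels[text_start + delay_step : text_start + delay_step + 4] = speech_codec
audio_labels[text_start + delay_step + 4] = eos_id

print(audio_codes.tolist())   # [bos]*6 + [101, 102, 103, 104] + [pad]*3
print(audio_labels.tolist())  # [pad]*5 + [101, 102, 103, 104] + [eos] + [pad]*3

So at every position t in the codec region, audio_labels[t] equals audio_codes[t + 1]: the codec LM is trained to predict the next codec token, delayed by delay_step frames relative to the text.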
@@ -70,7 +70,12 @@ from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2Config,
    Qwen2ForCausalLM,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
@@ -135,6 +140,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Whether to unfreeze llm during training.",
    )

    parser.add_argument(
        "--unfreeze-speech-projector",
        type=str2bool,
        default=False,
        help="Whether to unfreeze speech adaptor during training.",
    )

    parser.add_argument(
        "--enable-speech-output",
        type=str2bool,
        default=False,
        help="Whether to enable speech codec output.",
    )

def get_parser():
    parser = argparse.ArgumentParser(
@@ -307,7 +325,7 @@ def compute_loss(
    )
    # padding texts to the same length, texts is a list of list, padding with tokenizer.pad_token_id
    # remove too long text
    texts = [ text for text in texts if len(text) < 1024 ]
    # texts = [ text for text in texts if len(text) < 1024 ]
    if len(texts) != len(messages):
        logging.warning(
            f"Remove too long text, {messages} "
@@ -392,13 +410,22 @@ def compute_loss(
    input_ids = input_ids.type(torch.LongTensor)

    with torch.set_grad_enabled(is_training):
        model_outputs, acc = model(
            fbank=feature,
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            labels=target_ids.to(device),
        )
        loss = model_outputs.loss
        if not params.enable_speech_output:
            loss, acc = model(
                fbank=feature,
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=target_ids.to(device),
            )
        else:
            text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
                fbank=feature,
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=target_ids.to(device),
                speech_codec_ids=answer_cosyvoice_speech_token,
            )
            loss = text_loss + codec_loss
        assert loss.requires_grad == is_training

    info = MetricsTracker()
@@ -412,7 +439,12 @@ def compute_loss(
    info["acc"] = (
        acc * info["frames"]
    )  # WAR: to avoid normalization by the number of frames

    if params.enable_speech_output:
        info["codec_acc"] = (
            codec_acc * info["frames"]
        )
        info["codec_loss"] = codec_loss.detach().cpu().item()
        info["text_loss"] = text_loss.detach().cpu().item()
    return loss, info

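On the acc * info["frames"] and codec_acc * info["frames"] lines above: the working assumption (consistent with the existing "WAR" comment) is that MetricsTracker normalizes each entry by the accumulated frame count when it is reported, so pre-multiplying by the batch's frame count makes the reported value come back out as a frame-weighted accuracy instead of an accuracy divided by frames. A toy illustration in plain Python, with made-up frame counts and accuracies:

# Toy stand-in for the tracker's assumed sum-then-divide-by-frames behaviour.
batches = [
    {"frames": 1200, "acc": 0.80},
    {"frames": 800, "acc": 0.90},
]
tot_frames = sum(b["frames"] for b in batches)
tot_acc = sum(b["acc"] * b["frames"] for b in batches)  # store acc pre-multiplied by frames
print(tot_acc / tot_frames)  # 0.84: a frame-weighted accuracy survives the normalization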
@@ -429,7 +461,7 @@ def compute_validation_loss(
    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        with torch.cuda.amp.autocast(enabled=params.use_fp16):
        with torch.amp.autocast('cuda', enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
@@ -544,7 +576,7 @@ def train_one_epoch(
                    f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                )
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
            with torch.amp.autocast('cuda', enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    tokenizer=tokenizer,
@@ -629,6 +661,7 @@ def run(rank, world_size, args):
        speech_encoder.eval()

    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
@@ -672,6 +705,16 @@ def run(rank, world_size, args):

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    # original_tokenizer_vocab_size = len(tokenizer)
    # cosyvoice2_token_size = 6561
    # new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
    #     "<|SPEECH_GENERATION_START|>"
    # ]
    # num_added_tokens = tokenizer.add_tokens(new_tokens)
    # model.resize_token_embeddings(len(tokenizer))
    # model.vocab_size = len(tokenizer)


    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
@@ -680,11 +723,66 @@ def run(rank, world_size, args):
    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )
    if not params.unfreeze_speech_projector:
        for name, param in encoder_projector.named_parameters():
            param.requires_grad = False
        encoder_projector.eval()


    if params.enable_speech_output:
        if params.use_flash_attn:
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "eager"
            torch_dtype = torch.float16
        # codec_lm = AutoModelForCausalLM.from_pretrained(
        #     params.llm_path_or_name,
        #     attn_implementation=attn_implementation,
        #     torch_dtype=torch_dtype,
        # )
        codec_vocab_size = 8192
        config = Qwen2Config(
            vocab_size=codec_vocab_size,
            hidden_size=1024,
            num_hidden_layers=12,
            num_attention_heads=16,
            num_key_value_heads=16,
            intermediate_size=2048,
            max_position_embeddings=4096,
        )
        codec_lm = Qwen2ForCausalLM(config=config)
        # cosyvoice2_token_size = 6561
        codec_lm.resize_token_embeddings(codec_vocab_size)
        codec_lm.vocab_size = codec_vocab_size
        codec_lm.config.pad_token_id = codec_vocab_size - 1
        codec_lm.config.eos_token_id = codec_vocab_size - 2
        codec_lm.config.bos_token_id = codec_vocab_size - 3
        if params.use_lora:
            lora_config = LoraConfig(
                r=64,
                lora_alpha=16,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "up_proj",
                    "gate_proj",
                    "down_proj",
                ],
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            codec_lm = get_peft_model(codec_lm, lora_config)
            codec_lm.print_trainable_parameters()
    else:
        codec_lm = None

    model = SPEECH_LLM(
        speech_encoder,
        llm,
        encoder_projector,
        codec_lm,
    )

    if params.pretrained_model_path:
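One detail in the codec-LM construction above: the pad, EOS, and BOS ids are carved out of the top of the 8192-entry codec vocabulary (8191, 8190, 8189), so they cannot collide with real codec tokens if, as the commented-out cosyvoice2_token_size = 6561 suggests, CosyVoice2 codec ids occupy 0..6560 (that range is an assumption here). A standalone sketch of the same construction without the LoRA wrapper, with the parameter count computed at run time:

from transformers import Qwen2Config, Qwen2ForCausalLM

codec_vocab_size = 8192
config = Qwen2Config(
    vocab_size=codec_vocab_size,
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    num_key_value_heads=16,
    intermediate_size=2048,
    max_position_embeddings=4096,
)
codec_lm = Qwen2ForCausalLM(config=config)

# Reserved special ids sit above every real codec token id.
codec_lm.config.pad_token_id = codec_vocab_size - 1  # 8191
codec_lm.config.eos_token_id = codec_vocab_size - 2  # 8190
codec_lm.config.bos_token_id = codec_vocab_size - 3  # 8189
assert codec_lm.config.bos_token_id > 6560  # assumed CosyVoice2 codec range 0..6560

n_params = sum(p.numel() for p in codec_lm.parameters())
print(f"codec LM parameters: {n_params / 1e6:.1f} M")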
@@ -728,6 +826,13 @@ def run(rank, world_size, args):
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
        codec_len = len(c.custom["answer_cosyvoice_speech_token"])
        if codec_len > 2048:
            logging.warning(
                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
            )
            return False
        return True
