add codec lm

This commit is contained in:
root 2025-04-21 01:00:06 +00:00
parent 458d697acc
commit bdb60f6ddc
3 changed files with 251 additions and 13 deletions

View File

@@ -96,3 +96,22 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --use-lora True --unfreeze-llm True
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "stage 4: train speech2text model with speech codec output"
  ngpu=2
  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --max-duration 40 \
    --enable-musan False \
    --exp-dir ./slam_omni/exp_speech2text \
    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --manifest-dir data/fbank \
    --deepspeed \
    --deepspeed_config ./slam_omni/ds_config_zero1.json \
    --use-flash-attn False \
    --use-lora True --unfreeze-llm False --enable-speech-output True
    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
    # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
fi
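Note: stage 4 turns on `--enable-speech-output True`, which assumes every training cut carries CosyVoice codec tokens in its `custom` field (the filter added to train.py below checks `answer_cosyvoice_speech_token`). A quick sanity-check sketch, with an assumed manifest filename under `data/fbank`:

```python
# Hedged sketch: verify the cuts used by stage 4 carry codec-token targets.
# "cuts_train.jsonl.gz" is a placeholder name, not necessarily the real manifest.
from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")
cut = next(iter(cuts))
tokens = cut.custom["answer_cosyvoice_speech_token"]
print(cut.id, len(tokens), tokens[:10])
```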

View File

@@ -58,11 +58,21 @@ class SPEECH_LLM(nn.Module):
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        codec_lm: nn.Module = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
        self.codec_lm = codec_lm
        if self.codec_lm:
            self.speech_token_projector = nn.Linear(
                self.llm.config.hidden_size, self.codec_lm.config.hidden_size
            )
            self.codec_lm_head = nn.Linear(
                self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
            )
            self.loss_fct = torch.nn.CrossEntropyLoss()
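The two new layers only adapt tensor widths: `speech_token_projector` maps LLM hidden states into the codec LM's hidden space, and `codec_lm_head` turns codec hidden states into codec-token logits. A minimal shape sketch, assuming the 1024/8192 codec sizes configured later in this commit and a hypothetical LLM hidden size of 896:

```python
# Shape sketch only; 896 is an assumed LLM hidden size, 1024/8192 mirror the
# Qwen2Config used for the codec LM in train.py.
import torch
import torch.nn as nn

llm_hidden, codec_hidden, codec_vocab = 896, 1024, 8192
speech_token_projector = nn.Linear(llm_hidden, codec_hidden)
codec_lm_head = nn.Linear(codec_hidden, codec_vocab)

text_hidden = torch.randn(2, 50, llm_hidden)     # (batch, text_len, llm_hidden)
projected = speech_token_projector(text_hidden)  # (batch, text_len, codec_hidden)
logits = codec_lm_head(projected)                # (batch, text_len, codec_vocab)
print(projected.shape, logits.shape)
```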
    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@@ -225,8 +235,112 @@ class SPEECH_LLM(nn.Module):
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs.loss, acc

    def forward_with_speech_output(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
        speech_codec_ids: torch.LongTensor = None,
    ):
        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )
        # get the label start_index in inputs_embeds from labels
        text_label_start_index_list = []
        for i in range(labels.shape[0]):
            text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
            text_label_start_index_list.append(text_label_start_index)

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
        )
        text_loss = model_outputs.loss

        # prepare codec lm inputs
        audio_codes_lens = torch.tensor(
            [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
        )
        # print(audio_codes_lens, "audio_codes_lens")
        max_len_speech_codec = max(audio_codes_lens)
        delay_step = 2
        audio_codes = torch.full(
            (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
            self.codec_lm.config.pad_token_id,
            dtype=torch.int64,
            device=input_ids.device
        )
        audio_labels = audio_codes.clone()
        for i, speech_codec in enumerate(speech_codec_ids):
            text_label_start_index = text_label_start_index_list[i]
            speech_codec = torch.tensor(
                speech_codec, dtype=torch.int64, device=input_ids.device
            )
            # print(inputs_embeds[i, text_label_start_index], "2333 test")
            audio_codes[i, : text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
            audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
            audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
            audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
        audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

        # input_ids: seq_len T1, audio_codec seq_len T2
        text_last_hidden_outputs = model_outputs.hidden_states[-1]
        text_input_embeds = inputs_embeds + text_last_hidden_outputs
        text_input_embeds = self.speech_token_projector(text_input_embeds)
        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds

        speech_outputs = self.codec_lm(
            attention_mask=audio_attention_mask,
            inputs_embeds=audio_embeddings,
            return_dict=True,
            output_hidden_states=True,
        )
        last_hidden_state = speech_outputs.hidden_states[-1].clone()

        audio_logits = self.codec_lm_head(last_hidden_state)  # shape, B, T, vocab_size
        audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
        audio_labels = audio_labels.contiguous().view(-1)
        audio_labels = audio_labels.masked_fill(
            audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
        )
        codec_loss = self.loss_fct(audio_logits, audio_labels)
        audio_preds = torch.argmax(audio_logits, -1)

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
            audio_acc = compute_accuracy(
                audio_preds.detach(),
                audio_labels.detach(),
                ignore_label=IGNORE_TOKEN_ID,
            )
        return text_loss, acc, codec_loss, audio_acc

    def decode(
        self,
        fbank: torch.Tensor = None,
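The least obvious part of `forward_with_speech_output` above is the layout of `audio_codes` (codec LM inputs) and `audio_labels` (targets): positions up to the first text label plus `delay_step + 1` are filled with the BOS token (used as a mask), the codec tokens follow, and the labels are the same codec sequence shifted one position earlier with an EOS appended, so each input position predicts the next codec token. A toy single-sequence illustration with made-up values:

```python
# Toy illustration of the layout built in forward_with_speech_output.
# Values are made up; token ids follow the convention set in train.py
# (pad = vocab-1, eos = vocab-2, bos = vocab-3 with vocab = 8192).
import torch

pad_id, eos_id, bos_id = 8191, 8190, 8189
delay_step = 2
text_label_start = 4                       # first non-IGNORE position in the text labels
speech_codec = torch.tensor([11, 22, 33])  # hypothetical codec token ids
total_len = 12                             # text_len + max_codec_len + 1 in the real code

audio_codes = torch.full((total_len,), pad_id)
audio_labels = audio_codes.clone()

start = text_label_start + delay_step + 1
audio_codes[:start] = bos_id                                  # bos acts as a "mask" token
audio_codes[start : start + len(speech_codec)] = speech_codec
audio_labels[start - 1 : start - 1 + len(speech_codec)] = speech_codec  # shifted left by one
audio_labels[start - 1 + len(speech_codec)] = eos_id

print(audio_codes.tolist())   # [bos]*7 + [11, 22, 33] + [pad, pad]
print(audio_labels.tolist())  # [pad]*6 + [11, 22, 33, eos] + [pad, pad]
```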

View File

@@ -70,7 +70,12 @@ from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2Config,
    Qwen2ForCausalLM,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
@@ -135,6 +140,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Whether to unfreeze llm during training.",
    )

    parser.add_argument(
        "--unfreeze-speech-projector",
        type=str2bool,
        default=False,
        help="Whether to unfreeze speech adaptor during training.",
    )

    parser.add_argument(
        "--enable-speech-output",
        type=str2bool,
        default=False,
        help="Whether to enable speech codec output.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
@@ -307,7 +325,7 @@ def compute_loss(
    )
    # padding texts to the same length, texts is a list of list, padding with tokenizer.pad_token_id
    # remove too long text
    # texts = [ text for text in texts if len(text) < 1024 ]
    if len(texts) != len(messages):
        logging.warning(
            f"Remove too long text, {messages} "
@@ -392,13 +410,22 @@ def compute_loss(
    input_ids = input_ids.type(torch.LongTensor)

    with torch.set_grad_enabled(is_training):
        if not params.enable_speech_output:
            loss, acc = model(
                fbank=feature,
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=target_ids.to(device),
            )
        else:
            text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
                fbank=feature,
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=target_ids.to(device),
                speech_codec_ids=answer_cosyvoice_speech_token,
            )
            loss = text_loss + codec_loss
    assert loss.requires_grad == is_training

    info = MetricsTracker()
@@ -412,7 +439,12 @@ def compute_loss(
    info["acc"] = (
        acc * info["frames"]
    ) # WAR: to avoid normalization by the number of frames
    if params.enable_speech_output:
        info["codec_acc"] = (
            codec_acc * info["frames"]
        )
        info["codec_loss"] = codec_loss.detach().cpu().item()
        info["text_loss"] = text_loss.detach().cpu().item()

    return loss, info
@@ -429,7 +461,7 @@ def compute_validation_loss(
    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        with torch.amp.autocast('cuda', enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
@@ -544,7 +576,7 @@ def train_one_epoch(
                    f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                )
        try:
            with torch.amp.autocast('cuda', enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    tokenizer=tokenizer,
@@ -629,6 +661,7 @@ def run(rank, world_size, args):
    speech_encoder.eval()

    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
@@ -672,6 +705,16 @@ def run(rank, world_size, args):
    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)

    # original_tokenizer_vocab_size = len(tokenizer)
    # cosyvoice2_token_size = 6561
    # new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
    #     "<|SPEECH_GENERATION_START|>"
    # ]
    # num_added_tokens = tokenizer.add_tokens(new_tokens)
    # model.resize_token_embeddings(len(tokenizer))
    # model.vocab_size = len(tokenizer)

    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
@@ -680,11 +723,66 @@ def run(rank, world_size, args):
    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    if not params.unfreeze_speech_projector:
        for name, param in encoder_projector.named_parameters():
            param.requires_grad = False
        encoder_projector.eval()

    if params.enable_speech_output:
        if params.use_flash_attn:
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "eager"
        torch_dtype = torch.float16
        # codec_lm = AutoModelForCausalLM.from_pretrained(
        #     params.llm_path_or_name,
        #     attn_implementation=attn_implementation,
        #     torch_dtype=torch_dtype,
        # )
        codec_vocab_size = 8192
        config = Qwen2Config(
            vocab_size=codec_vocab_size,
            hidden_size=1024,
            num_hidden_layers=12,
            num_attention_heads=16,
            num_key_value_heads=16,
            intermediate_size=2048,
            max_position_embeddings=4096,
        )
        codec_lm = Qwen2ForCausalLM(config=config)
        # cosyvoice2_token_size = 6561
        codec_lm.resize_token_embeddings(codec_vocab_size)
        codec_lm.vocab_size = codec_vocab_size
        codec_lm.config.pad_token_id = codec_vocab_size - 1
        codec_lm.config.eos_token_id = codec_vocab_size - 2
        codec_lm.config.bos_token_id = codec_vocab_size - 3
        if params.use_lora:
            lora_config = LoraConfig(
                r=64,
                lora_alpha=16,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "up_proj",
                    "gate_proj",
                    "down_proj",
                ],
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            codec_lm = get_peft_model(codec_lm, lora_config)
            codec_lm.print_trainable_parameters()
    else:
        codec_lm = None

    model = SPEECH_LLM(
        speech_encoder,
        llm,
        encoder_projector,
        codec_lm,
    )

    if params.pretrained_model_path:
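Note that the codec LM is a small Qwen2 decoder built from a fresh config rather than a pretrained checkpoint, with the top three vocabulary ids reserved as pad/eos/bos. A hedged sketch to gauge its size under the config above:

```python
# Rough parameter count of the codec LM configured above (no pretrained weights).
from transformers import Qwen2Config, Qwen2ForCausalLM

config = Qwen2Config(
    vocab_size=8192,
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    num_key_value_heads=16,
    intermediate_size=2048,
    max_position_embeddings=4096,
)
codec_lm = Qwen2ForCausalLM(config=config)
n_params = sum(p.numel() for p in codec_lm.parameters())
print(f"codec LM parameters: {n_params / 1e6:.1f}M")
```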
@@ -728,6 +826,13 @@ def run(rank, world_size, args):
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False

        # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
        codec_len = len(c.custom["answer_cosyvoice_speech_token"])
        if codec_len > 2048:
            logging.warning(
                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
            )
            return False

        return True