add debug script

This commit is contained in:
root 2025-05-08 03:37:26 -07:00
parent 37db65984c
commit bd2df570ad
4 changed files with 2447 additions and 0 deletions

View File

@ -0,0 +1,480 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import load_dataset
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
load_manifest,
load_manifest_lazy,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=300.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
default="question_audio",
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
group.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
"""
Args:
cuts_valid:
CutSet for validation.
"""
logging.info("About to create dev dataset")
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
else:
valid_sampler = SimpleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return {
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
slam_omni_zh_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
)
return slam_omni_zh_cuts
# @lru_cache()
# def train_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get train cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
# )
# ultrachat_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
# )
# return CutSet.mux(
# VoiceAssistant_cuts,
# ultrachat_cuts,
# weights=[
# len(VoiceAssistant_cuts),
# len(ultrachat_cuts),
# ],
# )
# valid cuts_voice_assistant.00000.jsonl.gz
# @lru_cache()
# def valid_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get valid cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
# )
# return VoiceAssistant_cuts
# @lru_cache()
# def test_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get test cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
# )
# return VoiceAssistant_cuts
def train_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get train cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts
@lru_cache()
def valid_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get valid cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts
@lru_cache()
def test_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts

View File

@ -0,0 +1,795 @@
from typing import List, Tuple
import torch
from torch import nn
from torchmetrics.classification import MulticlassAccuracy
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
from utils import get_rank
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.downsample_rate
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder: nn.Module,
llm: nn.Module,
encoder_projector: nn.Module,
codec_lm: nn.Module = None,
codec_lm_padding_side: str = "left",
):
super().__init__()
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size + self.llm.config.hidden_size,
self.codec_lm.config.hidden_size,
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
)
self.speech_token_projector = self.speech_token_projector.to(
dtype=torch.float16
)
self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
self.loss_fct = torch.nn.CrossEntropyLoss()
self.codec_lm_padding_side = codec_lm_padding_side
self.audio_accuracy_metric = MulticlassAccuracy(
self.codec_lm.vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_speech_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_speech_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
rank = get_rank()
print(f"Current rank: {rank}, input_ids: {input_ids.shape}, input_ids: {input_ids}")
print(f"Current rank: {rank}, input_embeds: {inputs_embeds.shape}, input_embeds: {inputs_embeds}")
print(f"Current rank: {rank}, attention_mask: {attention_mask.shape}, attention_mask: {attention_mask}")
print(f"Current rank: {rank}, labels: {labels.shape}, labels: {labels}")
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
print(f"Current rank: {rank}, model_outputs: {model_outputs}")
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs.loss, acc
def forward_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
speech_codec_ids: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
input_seq_len = attention_mask.sum(dim=1) # shape, B
(
text_label_start_index_list,
text_input_start_index_list,
input_question_len_list,
) = ([], [], [])
for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
assert (
input_seq_len[i]
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert (
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert (
input_question_len
+ text_labels_valid_index[-1]
- text_labels_start_index
+ 1
== input_seq_len[i]
)
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
rank = get_rank()
print(f"Current rank: {rank}, input_ids: {input_ids.shape}, input_ids: {input_ids}")
print(f"Current rank: {rank}, input_embeds: {inputs_embeds.shape}, input_embeds: {inputs_embeds}")
print(f"Current rank: {rank}, attention_mask: {attention_mask.shape}, attention_mask: {attention_mask}")
print(f"Current rank: {rank}, labels: {labels.shape}, labels: {labels}")
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
print(f"Current rank: {rank}, model_outputs: {model_outputs}")
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
audio_codes_lens = [
len(x) + input_question_len_list[i] + delay_step + 1
for i, x in enumerate(speech_codec_ids)
]
max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right":
audio_codes = [
[self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes, dtype=torch.int64, device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels, dtype=torch.int64, device=input_ids.device
)
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][
i,
text_input_start_index_list[i] : text_input_start_index_list[i]
+ input_seq_len[i]
- 1,
]
print(233336666666, text_last_hidden, text_last_hidden.shape)
text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
[
text_last_hidden,
text_embed,
],
dim=-1,
) # shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(
text_input_embeds
) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
start_idx = torch.where(
audio_codes[i] == self.codec_lm.config.mask_token_id
)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert (
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
)
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
inputs_embeds=audio_embeddings,
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = speech_outputs.hidden_states[-1].clone()
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(
-1, self.codec_lm.config.vocab_size
)
audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
)
codec_loss = self.loss_fct(audio_logits, audio_labels)
audio_preds = torch.argmax(audio_logits, -1)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
print(23333444444, preds)
print(233335555555, labels)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
audio_acc = compute_accuracy(
audio_preds.detach(),
audio_labels.detach(),
ignore_label=IGNORE_TOKEN_ID,
)
audio_topk_acc = self.audio_accuracy_metric(
audio_logits.detach(), audio_labels.detach()
).item()
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_new_tokens", 1024),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", True),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 0.5),
top_k=kwargs.get("top_k", 20),
repetition_penalty=kwargs.get("repetition_penalty", 1.1),
temperature=kwargs.get("temperature", 0.7),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 2048, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
Args:
fbank: Input audio features.
input_ids: Input token IDs for the text prompt.
attention_mask: Attention mask for the text prompt.
max_text_new_tokens: Max new tokens for text generation.
max_speech_new_tokens: Max new tokens for speech generation.
llm_kwargs: Additional arguments for self.llm.generate.
codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
Returns:
Tuple[torch.LongTensor, List[List[int]]]:
- generated_text_ids: Tensor of generated text token IDs (including prompt).
- generated_speech_tokens: List of lists, where each inner list contains
the generated speech codec tokens for a batch item.
"""
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
if (
not self.codec_lm
or not self.speech_token_projector
or not self.codec_lm_head
):
raise ValueError(
"codec_lm and associated layers must be initialized to generate speech output."
)
device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
# Merge speech features with prompt embeddings
(
merged_prompt_inputs_embeds,
merged_prompt_attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, prompt_embeds, input_ids, attention_mask
)
# --- 2. Generate Text using LLM ---
# Use merged embeds/mask as input to generate
# Ensure kwargs passed are suitable for llm.generate
# Note: Using default generation params from `decode` if not provided in kwargs
final_llm_kwargs = {
"bos_token_id": self.llm.config.bos_token_id,
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
**(llm_kwargs or {}), # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
inputs_embeds=merged_prompt_inputs_embeds,
attention_mask=merged_prompt_attention_mask,
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
output_hidden_states=True,
**final_llm_kwargs,
)
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(
torch.tensor([[eos_token_id]], device=device)
)
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
thinker_token_embeds_org[1],
],
dim=1,
)
thinker_token_embeds = (
[first_thinker_token_embed]
+ thinker_token_embeds_org[2:]
+ [eos_token_embedding]
)
thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
thinker_reply_part = [
torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(
thinker_hidden_states[1:], thinker_token_embeds[1:]
)
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
[
thinker_hidden_states[0],
thinker_token_embeds[0],
],
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
self.codec_lm.config.mask_token_id,
dtype=torch.long,
device=self.llm.device,
)
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
thinker_reply_part[:, : delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
past_key_values = None
generated_speech_tokens_list = []
next_token_ids = None
for t in range(max_speech_new_tokens):
if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
)
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[:, 1:, :]
codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
)
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
next_token_logits = self.codec_lm_head(last_token_hidden_state)
next_token_ids = topk_sampling(
next_token_logits,
)
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
return generated_text_ids, generated_speech_tokens_list
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()
def topk_sampling(
logits,
top_k=50,
top_p=0.95,
temperature=0.8,
):
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits_filtered = top_k_top_p_filtering(
logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
)
# Sample
probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
tokens = torch.multinomial(probs, num_samples=1)
return tokens
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits

View File

@ -0,0 +1,195 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
set -eou pipefail
stage=$1
stop_stage=$2
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
# docker: ghcr.io/swivid/f5-tts:main
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
cd /workspace/CosyVoice
# If you failed to clone submodule due to network failures, please run following command until success
git submodule update --init --recursive
pip install -r qwen_omni/requirements.txt
pip install -r qwen_omni/requirements-cosyvoice.txt
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
pip install sherpa-onnx
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Compute fbank feature from huggingface"
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_test \
--huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix belle
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Combine features"
manifest_dir=data/fbank
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
# exclude cust_belle_00000.jsonl.gz for valid and test set
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
echo $pieces | wc
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
fi
fi
ngpu=8
exp_dir=./qwen_omni/exp_speech2speech
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Decoding, only support batch_size=1 for now."
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Gradio Demo"
python3 ./qwen_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-999.pt \
--use-flash-attn True \
--enable-speech-output True \
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 1: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
# --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix voice_assistant
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_voice_assistant_cosy2 \
--json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
--prefix voice_assistant
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_ultrachat \
# --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix ultrachat
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_ultrachat_cosy2 \
--json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
--prefix ultrachat
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "stage 8: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "stage 9: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset xl --split train \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
ngpu=2
exp_dir=./qwen_omni/exp_speech2speech_en
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "stage 10: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 1 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--dataset-format vocalnet \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn False --bucketing-sampler False \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi

View File

@ -0,0 +1,977 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from data_module import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# from icefall import diagnostics
from utils import get_rank, get_world_size
# from icefall.env import get_env_info
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
)
parser.add_argument(
"--enable-speech-output",
type=str2bool,
default=False,
help="Whether to enable speech codec output.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper_qwen/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
""",
)
parser.add_argument(
"--sampler-state-dict-path",
type=str,
default=None,
help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--unfreeze-llm",
type=str2bool,
default=False,
help="Whether to unfreeze llm during training.",
)
parser.add_argument(
"--unfreeze-speech-projector",
type=str2bool,
default=False,
help="Whether to unfreeze speech adaptor during training.",
)
parser.add_argument(
"--dataset-format",
type=str,
default="slam_omni",
help="The format of the dataset.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- frame_shift_ms: The frame shift in milliseconds.
- allowed_excess_duration_ratio: The allowed excess duration ratio.
- best_train_loss: The best training loss so far.
- best_valid_loss: The best validation loss so far.
- best_train_epoch: The epoch where the best training loss is achieved.
- best_valid_epoch: The epoch where the best validation loss is achieved.
- batch_idx_train: The batch index of the current batch.
- log_interval: Log training stats every `log_interval` batches.
- reset_interval: Reset the stats every `reset_interval` batches.
- valid_interval: Run validation every `valid_interval` batches.
- env_info: The environment information.
"""
params = AttributeDict(
{
"allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
# "env_info": get_env_info(),
}
)
return params
def process_batch_slam_omni(batch: dict):
answers = batch["supervisions"]["text"]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def process_batch_vocalnet(batch: dict):
answers = batch["supervisions"]["text"]
answer_cosyvoice_speech_token = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
messages = []
for i in range(len(answers)):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
# WAR: TODO FIXME check qwen3
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
# max_frames = params.max_duration * 1000 // params.frame_shift_ms
# allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
# batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
if params.dataset_format == "slam_omni":
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
print(f"messages: {messages}")
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
if not params.enable_speech_output:
loss, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
else:
(
text_loss,
acc,
codec_loss,
codec_acc,
codec_topk_acc,
) = model.forward_with_speech_output(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
speech_codec_ids=answer_cosyvoice_speech_token,
)
loss = text_loss + codec_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = batch["supervisions"]["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
if params.enable_speech_output:
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
exit()
return tot_loss
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.encoder_projector.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
# in that it does not accept a loss tensor as input.
# It computes the loss internally.
model.backward(loss)
model.step()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except: # noqa
cur_lr = 0.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if not params.unfreeze_speech_projector:
for name, param in encoder_projector.named_parameters():
param.requires_grad = False
encoder_projector.eval()
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
if params.dataset_format == "slam_omni":
codec_vocab_size = 4096 + 4
elif params.dataset_format == "vocalnet":
codec_vocab_size = 6561 + 4
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
# TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
logging.info("Trainable parameters (excluding model.eval modules):")
for name, param in model.named_parameters():
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
data_module = AsrDataModule(args)
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
codec_len = (
len(c.custom["answer_cosyvoice_speech_token"])
if "answer_cosyvoice_speech_token" in c.custom
else len(c.custom["speech_token"])
)
if codec_len > 2200:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
)
return False
return True
if params.dataset_format == "slam_omni":
train_cuts = data_module.train_cuts()
valid_cuts = data_module.dev_cuts()
elif params.dataset_format == "vocalnet":
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
train_cuts = train_cuts.filter(remove_short_and_long_utt)
valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
tokenizer=tokenizer,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
logging.info("Done!")
def display_and_save_batch(
batch: dict,
params: AttributeDict,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()