refactor code

This commit is contained in:
root 2025-05-22 19:14:52 -07:00
parent 7a12d88d6c
commit 9fff18edec
11 changed files with 141 additions and 2874 deletions

View File

@ -1,480 +0,0 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import load_dataset
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
load_manifest,
load_manifest_lazy,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=300.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
default="question_audio",
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
group.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
"""
Args:
cuts_valid:
CutSet for validation.
"""
logging.info("About to create dev dataset")
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
else:
valid_sampler = SimpleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return {
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
slam_omni_zh_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
)
return slam_omni_zh_cuts
# @lru_cache()
# def train_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get train cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
# )
# ultrachat_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
# )
# return CutSet.mux(
# VoiceAssistant_cuts,
# ultrachat_cuts,
# weights=[
# len(VoiceAssistant_cuts),
# len(ultrachat_cuts),
# ],
# )
# valid cuts_voice_assistant.00000.jsonl.gz
# @lru_cache()
# def valid_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get valid cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
# )
# return VoiceAssistant_cuts
# @lru_cache()
# def test_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get test cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
# )
# return VoiceAssistant_cuts
def train_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get train cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts
@lru_cache()
def valid_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get valid cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts
@lru_cache()
def test_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_debug.jsonl.gz"
)
return VoiceAssistant_cuts

View File

@ -1,795 +0,0 @@
from typing import List, Tuple
import torch
from torch import nn
from torchmetrics.classification import MulticlassAccuracy
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
from utils import get_rank
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.downsample_rate
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder: nn.Module,
llm: nn.Module,
encoder_projector: nn.Module,
codec_lm: nn.Module = None,
codec_lm_padding_side: str = "left",
):
super().__init__()
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size + self.llm.config.hidden_size,
self.codec_lm.config.hidden_size,
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
)
self.speech_token_projector = self.speech_token_projector.to(
dtype=torch.float16
)
self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
self.loss_fct = torch.nn.CrossEntropyLoss()
self.codec_lm_padding_side = codec_lm_padding_side
self.audio_accuracy_metric = MulticlassAccuracy(
self.codec_lm.vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_speech_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_speech_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
rank = get_rank()
print(f"Current rank: {rank}, input_ids: {input_ids.shape}, input_ids: {input_ids}")
print(f"Current rank: {rank}, input_embeds: {inputs_embeds.shape}, input_embeds: {inputs_embeds}")
print(f"Current rank: {rank}, attention_mask: {attention_mask.shape}, attention_mask: {attention_mask}")
print(f"Current rank: {rank}, labels: {labels.shape}, labels: {labels}")
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
print(f"Current rank: {rank}, model_outputs: {model_outputs}")
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs.loss, acc
def forward_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
speech_codec_ids: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
input_seq_len = attention_mask.sum(dim=1) # shape, B
(
text_label_start_index_list,
text_input_start_index_list,
input_question_len_list,
) = ([], [], [])
for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
assert (
input_seq_len[i]
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert (
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert (
input_question_len
+ text_labels_valid_index[-1]
- text_labels_start_index
+ 1
== input_seq_len[i]
)
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
rank = get_rank()
print(f"Current rank: {rank}, input_ids: {input_ids.shape}, input_ids: {input_ids}")
print(f"Current rank: {rank}, input_embeds: {inputs_embeds.shape}, input_embeds: {inputs_embeds}")
print(f"Current rank: {rank}, attention_mask: {attention_mask.shape}, attention_mask: {attention_mask}")
print(f"Current rank: {rank}, labels: {labels.shape}, labels: {labels}")
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
print(f"Current rank: {rank}, model_outputs: {model_outputs}")
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
audio_codes_lens = [
len(x) + input_question_len_list[i] + delay_step + 1
for i, x in enumerate(speech_codec_ids)
]
max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right":
audio_codes = [
[self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes, dtype=torch.int64, device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels, dtype=torch.int64, device=input_ids.device
)
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][
i,
text_input_start_index_list[i] : text_input_start_index_list[i]
+ input_seq_len[i]
- 1,
]
print(233336666666, text_last_hidden, text_last_hidden.shape)
text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
[
text_last_hidden,
text_embed,
],
dim=-1,
) # shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(
text_input_embeds
) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
start_idx = torch.where(
audio_codes[i] == self.codec_lm.config.mask_token_id
)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert (
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
)
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
inputs_embeds=audio_embeddings,
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = speech_outputs.hidden_states[-1].clone()
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(
-1, self.codec_lm.config.vocab_size
)
audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
)
codec_loss = self.loss_fct(audio_logits, audio_labels)
audio_preds = torch.argmax(audio_logits, -1)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
print(23333444444, preds)
print(233335555555, labels)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
audio_acc = compute_accuracy(
audio_preds.detach(),
audio_labels.detach(),
ignore_label=IGNORE_TOKEN_ID,
)
audio_topk_acc = self.audio_accuracy_metric(
audio_logits.detach(), audio_labels.detach()
).item()
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_new_tokens", 1024),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", True),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 0.5),
top_k=kwargs.get("top_k", 20),
repetition_penalty=kwargs.get("repetition_penalty", 1.1),
temperature=kwargs.get("temperature", 0.7),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 2048, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
Args:
fbank: Input audio features.
input_ids: Input token IDs for the text prompt.
attention_mask: Attention mask for the text prompt.
max_text_new_tokens: Max new tokens for text generation.
max_speech_new_tokens: Max new tokens for speech generation.
llm_kwargs: Additional arguments for self.llm.generate.
codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
Returns:
Tuple[torch.LongTensor, List[List[int]]]:
- generated_text_ids: Tensor of generated text token IDs (including prompt).
- generated_speech_tokens: List of lists, where each inner list contains
the generated speech codec tokens for a batch item.
"""
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
if (
not self.codec_lm
or not self.speech_token_projector
or not self.codec_lm_head
):
raise ValueError(
"codec_lm and associated layers must be initialized to generate speech output."
)
device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
# Merge speech features with prompt embeddings
(
merged_prompt_inputs_embeds,
merged_prompt_attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, prompt_embeds, input_ids, attention_mask
)
# --- 2. Generate Text using LLM ---
# Use merged embeds/mask as input to generate
# Ensure kwargs passed are suitable for llm.generate
# Note: Using default generation params from `decode` if not provided in kwargs
final_llm_kwargs = {
"bos_token_id": self.llm.config.bos_token_id,
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
**(llm_kwargs or {}), # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
inputs_embeds=merged_prompt_inputs_embeds,
attention_mask=merged_prompt_attention_mask,
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
output_hidden_states=True,
**final_llm_kwargs,
)
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(
torch.tensor([[eos_token_id]], device=device)
)
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
thinker_token_embeds_org[1],
],
dim=1,
)
thinker_token_embeds = (
[first_thinker_token_embed]
+ thinker_token_embeds_org[2:]
+ [eos_token_embedding]
)
thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
thinker_reply_part = [
torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(
thinker_hidden_states[1:], thinker_token_embeds[1:]
)
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
[
thinker_hidden_states[0],
thinker_token_embeds[0],
],
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
self.codec_lm.config.mask_token_id,
dtype=torch.long,
device=self.llm.device,
)
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
thinker_reply_part[:, : delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
past_key_values = None
generated_speech_tokens_list = []
next_token_ids = None
for t in range(max_speech_new_tokens):
if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
)
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[:, 1:, :]
codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
)
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
next_token_logits = self.codec_lm_head(last_token_hidden_state)
next_token_ids = topk_sampling(
next_token_logits,
)
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
return generated_text_ids, generated_speech_tokens_list
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()
def topk_sampling(
logits,
top_k=50,
top_p=0.95,
temperature=0.8,
):
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits_filtered = top_k_top_p_filtering(
logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
)
# Sample
probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
tokens = torch.multinomial(probs, num_samples=1)
return tokens
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits

View File

@ -1,195 +0,0 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
set -eou pipefail
stage=$1
stop_stage=$2
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
# docker: ghcr.io/swivid/f5-tts:main
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
cd /workspace/CosyVoice
# If you failed to clone submodule due to network failures, please run following command until success
git submodule update --init --recursive
pip install -r qwen_omni/requirements.txt
pip install -r qwen_omni/requirements-cosyvoice.txt
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
pip install sherpa-onnx
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Compute fbank feature from huggingface"
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_test \
--huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix belle
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Combine features"
manifest_dir=data/fbank
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
# exclude cust_belle_00000.jsonl.gz for valid and test set
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
echo $pieces | wc
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
fi
fi
ngpu=8
exp_dir=./qwen_omni/exp_speech2speech
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Decoding, only support batch_size=1 for now."
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Gradio Demo"
python3 ./qwen_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-999.pt \
--use-flash-attn True \
--enable-speech-output True \
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 1: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
# --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix voice_assistant
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_voice_assistant_cosy2 \
--json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
--prefix voice_assistant
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_ultrachat \
# --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix ultrachat
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_ultrachat_cosy2 \
--json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
--prefix ultrachat
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "stage 8: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "stage 9: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset xl --split train \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
ngpu=2
exp_dir=./qwen_omni/exp_speech2speech_en
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "stage 10: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 1 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--dataset-format vocalnet \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn False --bucketing-sampler False \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi

View File

@ -1,977 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from data_module import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# from icefall import diagnostics
from utils import get_rank, get_world_size
# from icefall.env import get_env_info
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
)
parser.add_argument(
"--enable-speech-output",
type=str2bool,
default=False,
help="Whether to enable speech codec output.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper_qwen/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
""",
)
parser.add_argument(
"--sampler-state-dict-path",
type=str,
default=None,
help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--unfreeze-llm",
type=str2bool,
default=False,
help="Whether to unfreeze llm during training.",
)
parser.add_argument(
"--unfreeze-speech-projector",
type=str2bool,
default=False,
help="Whether to unfreeze speech adaptor during training.",
)
parser.add_argument(
"--dataset-format",
type=str,
default="slam_omni",
help="The format of the dataset.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- frame_shift_ms: The frame shift in milliseconds.
- allowed_excess_duration_ratio: The allowed excess duration ratio.
- best_train_loss: The best training loss so far.
- best_valid_loss: The best validation loss so far.
- best_train_epoch: The epoch where the best training loss is achieved.
- best_valid_epoch: The epoch where the best validation loss is achieved.
- batch_idx_train: The batch index of the current batch.
- log_interval: Log training stats every `log_interval` batches.
- reset_interval: Reset the stats every `reset_interval` batches.
- valid_interval: Run validation every `valid_interval` batches.
- env_info: The environment information.
"""
params = AttributeDict(
{
"allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
# "env_info": get_env_info(),
}
)
return params
def process_batch_slam_omni(batch: dict):
answers = batch["supervisions"]["text"]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def process_batch_vocalnet(batch: dict):
answers = batch["supervisions"]["text"]
answer_cosyvoice_speech_token = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
messages = []
for i in range(len(answers)):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
# WAR: TODO FIXME check qwen3
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
# max_frames = params.max_duration * 1000 // params.frame_shift_ms
# allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
# batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
if params.dataset_format == "slam_omni":
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
print(f"messages: {messages}")
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
if not params.enable_speech_output:
loss, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
else:
(
text_loss,
acc,
codec_loss,
codec_acc,
codec_topk_acc,
) = model.forward_with_speech_output(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
speech_codec_ids=answer_cosyvoice_speech_token,
)
loss = text_loss + codec_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = batch["supervisions"]["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
if params.enable_speech_output:
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
exit()
return tot_loss
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.encoder_projector.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
# in that it does not accept a loss tensor as input.
# It computes the loss internally.
model.backward(loss)
model.step()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except: # noqa
cur_lr = 0.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if not params.unfreeze_speech_projector:
for name, param in encoder_projector.named_parameters():
param.requires_grad = False
encoder_projector.eval()
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
if params.dataset_format == "slam_omni":
codec_vocab_size = 4096 + 4
elif params.dataset_format == "vocalnet":
codec_vocab_size = 6561 + 4
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
# TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
logging.info("Trainable parameters (excluding model.eval modules):")
for name, param in model.named_parameters():
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
data_module = AsrDataModule(args)
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
codec_len = (
len(c.custom["answer_cosyvoice_speech_token"])
if "answer_cosyvoice_speech_token" in c.custom
else len(c.custom["speech_token"])
)
if codec_len > 2200:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
)
return False
return True
if params.dataset_format == "slam_omni":
train_cuts = data_module.train_cuts()
valid_cuts = data_module.dev_cuts()
elif params.dataset_format == "vocalnet":
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
train_cuts = train_cuts.filter(remove_short_and_long_utt)
valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
tokenizer=tokenizer,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
logging.info("Done!")
def display_and_save_batch(
batch: dict,
params: AttributeDict,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@ -240,11 +240,11 @@ fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "stage 14: Client"
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
# The final assignment of datasets in the original script is used here:
# (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
declare -a target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")
declare -a target_datasets=("alpacaeval_full" "wildvoice" "advbench" "bbh" "mmsu")
NUM_CLIENT_JOBS=4 # Number of parallel client jobs
BASE_PORT=8000 # Base port for servers
@ -365,7 +365,8 @@ if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
# pip install gradio sherpa-onnx
log "stage 17: Server for adapter only speech continuation"
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
# exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_asr_second_ce
exp_dir=./qwen_omni/exp_speech2text_first_multi_en_continuation_second_qa
N_GPUS=4 # Define the number of GPUs/processes you want to launch

View File

@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PerturbSpeed,
PrecomputedFeatures,
SimpleCutSampler,
@ -46,7 +47,6 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from utils import get_local_rank, str2bool
@ -203,21 +203,15 @@ class AsrDataModule:
group.add_argument(
"--audio-key",
type=str,
default="audio",
default=None,
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default="text",
default=None,
help="The key in the Huggingface dataset containing the text data",
)
# group.add_argument(
# "--resample-to-16kHz",
# type=str2bool,
# default=True,
# help="Resample audio to 16kHz. Default: False.",
# )
def train_dataloaders(
self,
@ -389,29 +383,21 @@ class AsrDataModule:
return test_dl
@lru_cache()
def test_cuts(self) -> CutSet:
def test_cuts_belle(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return {
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return load_manifest_lazy(
return {
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
def train_cuts(self) -> CutSet:
def dev_cuts_belle(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
@lru_cache()
def train_cuts_belle(self) -> CutSet:
logging.info("About to get train cuts")
slam_omni_zh_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_train.jsonl.gz"
@ -435,8 +421,6 @@ class AsrDataModule:
len(ultrachat_cuts),
],
)
# valid cuts_voice_assistant.00000.jsonl.gz
@lru_cache()
def valid_cuts_en_vocalnet(self) -> CutSet:
logging.info("About to get valid cuts")
@ -453,15 +437,6 @@ class AsrDataModule:
)
return {"test": VoiceAssistant_cuts}
def test_cuts_voicebench(
self,
) -> CutSet:
logging.info("About to get test cuts")
VoiceAssistant_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
)
return {"test": VoiceAssistant_cuts}
@lru_cache()
def train_cuts_ultravox(self) -> CutSet:
logging.info("About to get train cuts")
@ -556,65 +531,6 @@ class AsrDataModule:
],
)
# @lru_cache()
# def train_cuts_ultravox(self) -> CutSet:
# logging.info("About to get train cuts")
# keep_columns = ["audio", "text", "continuation", "id"]
# librispeech_path="fixie-ai/librispeech_asr"
# # 148_688
# librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True)
# # 104_014
# librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True)
# # 28_539
# librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
# cols_to_remove = librispeech_clean_100.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
# librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
# librispeech_other = librispeech_other.remove_columns(cols_to_remove)
# people_speech_path="fixie-ai/peoples_speech"
# # 1_501_271
# people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
# # 548_000
# people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
# cols_to_remove = people_speech_clean.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
# people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
# # 8_266_422
# gigaspeech_path="fixie-ai/gigaspeech"
# gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
# # first rename segment_id to id
# gigaspeech = gigaspeech.rename_column("segment_id", "id")
# cols_to_remove = gigaspeech.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# gigaspeech = gigaspeech.remove_columns(cols_to_remove)
# total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
# final_datasets = interleave_datasets([
# librispeech_clean_100,
# librispeech_clean_360,
# gigaspeech,
# people_speech_clean,
# people_speech_dirty_sa,
# librispeech_other,
# ], probabilities=[
# 28539 / total_item,
# 104014 / total_item,
# 8266422 / total_item,
# 1501271 / total_item,
# 548000 / total_item,
# 148688 / total_item,
# ])
# train_cuts = CutSet.from_huggingface_dataset(
# final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
# )
# return train_cuts
@lru_cache()
def valid_cuts_ultravox(self) -> CutSet:
logging.info("About to get valid cuts")

View File

@ -741,7 +741,7 @@ def main():
return True
# TODO: FIX ME
# test_sets_cuts = data_module.test_cuts()
# test_sets_cuts = data_module.test_cuts_belle()
test_sets_cuts = data_module.test_cuts_en_vocalnet()
test_sets = test_sets_cuts.keys()
test_dls = [

View File

@ -11,3 +11,5 @@ flash-attn
peft
torchmetrics
# triton==3.3.0 # may be violate with openai-whisper
gradio
sherpa-onnx

View File

@ -1,175 +0,0 @@
from typing import Callable, Dict, List, Union
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import DataLoader, default_collate
class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech recognition task using k2 library.
This dataset expects to be queried with lists of cut IDs,
for which it loads features and automatically collates/batches them.
To use it with a PyTorch DataLoader, set ``batch_size=None``
and provide a :class:`SimpleCutSampler` sampler.
Each item in this dataset is a dict of:
.. code-block::
{
'inputs': float tensor with shape determined by :attr:`input_strategy`:
- single-channel:
- features: (B, T, F)
- audio: (B, T)
- multi-channel: currently not supported
'supervisions': [
{
'sequence_idx': Tensor[int] of shape (S,)
'text': List[str] of len S
# For feature input strategies
'start_frame': Tensor[int] of shape (S,)
'num_frames': Tensor[int] of shape (S,)
# For audio input strategies
'start_sample': Tensor[int] of shape (S,)
'num_samples': Tensor[int] of shape (S,)
# Optionally, when return_cuts=True
'cut': List[AnyCut] of len S
}
]
}
Dimension symbols legend:
* ``B`` - batch size (number of Cuts)
* ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
* ``T`` - number of frames of the longest Cut
* ``F`` - number of features
The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
"""
def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
):
"""
k2 ASR IterableDataset constructor.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
# reset the internal HDF5 caches.
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_duration and max_cuts.
"""
validate_for_asr(cuts)
self.hdf5_fix.update()
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)
# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
# that only has cuts for which we succesfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]
return batch
def validate_for_asr(cuts: CutSet) -> None:
validate(cuts)
tol = 2e-3 # 1ms
for cut in cuts:
for supervision in cut.supervisions:
assert supervision.start >= -tol, (
f"Supervisions starting before the cut are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
#
# 'supervision.end' is end of supervision inside the Cut
assert supervision.end <= cut.duration + tol, (
f"Supervisions ending after the cut "
f"are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)

View File

@ -89,12 +89,6 @@ except RuntimeError:
pass
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
@ -143,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to enable speech codec output.",
)
parser.add_argument(
"--speech-tokenizer-type",
type=str,
default="cosyvoice2",
help="The type of the speech tokenizer. cosyvoice2: 6561, cosyvoice1: 4096",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -229,10 +230,10 @@ def get_parser():
)
parser.add_argument(
"--dataset-format",
"--prompt-template",
type=str,
default="slam_omni",
help="The format of the dataset.",
default="speech_qa",
help="The prompt template to use.",
)
parser.add_argument(
@ -291,123 +292,89 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 1000,
# "env_info": get_env_info(),
}
)
return params
def process_batch_slam_omni(batch: dict):
def extract_text_and_speech_token(
batch: dict,
prompt_template: str,
enable_speech_output: bool
) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
"""
Extracts messages and speech tokens from a batch based on the dataset format.
Uses the global DEFAULT_SPEECH_TOKEN.
"""
messages = []
speech_tokens = None # Initialize as None
if enable_speech_output:
if "answer_cosyvoice_speech_token" in batch["supervisions"]["cut"][0].custom:
assert "speech_token" not in batch["supervisions"]["cut"][0].custom
speech_tokens = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
elif "speech_token" in batch["supervisions"]["cut"][0].custom:
speech_tokens = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
else:
raise ValueError("Unknown speech token type")
answers = batch["supervisions"]["text"]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
batch_size = len(answers)
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
if prompt_template == "speech_qa":
for i in range(batch_size):
message_list_item = []
if 'round' in batch["supervisions"]["cut"][i].custom:
# slam_omni format dataset
current_question_with_history = batch["supervisions"]["cut"][i].custom["question"]
total_round = batch["supervisions"]["cut"][i].custom["round"]
history_context = current_question_with_history.rsplit("<USER>:", 1)[0].strip()
if total_round > 1:
history_question_answer = history_context.split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
question_answer = history_question_answer[j].split("ASSISTANT:")
message_list_item += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()},
]
message_list_item += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message_list_item)
elif prompt_template == "speech_continuation":
# speech_tokens remains None
for i in range(batch_size):
message_list_item = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\\n\\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": answers[i]},
]
messages.append(message_list_item)
def process_batch_vocalnet(batch: dict):
answers = batch["supervisions"]["text"]
answer_cosyvoice_speech_token = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
messages = []
for i in range(len(answers)):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def process_batch_text_vocalnet(batch: dict):
pass
answers = batch["supervisions"]["text"]
answer_cosyvoice_speech_token = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
messages = []
for i in range(len(answers)):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def process_batch_speech_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
# transcript = batch["supervisions"]["cut"][i].custom["text"]
messages.append(message)
return messages
def process_batch_asr(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
transcript = batch["supervisions"]["cut"][i].custom["text"]
message = [
{
"role": "user",
"content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": transcript},
]
messages.append(message)
return messages
def process_batch_text_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
transcript = batch["supervisions"]["cut"][i].custom["text"]
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
messages.append(message)
return messages
elif prompt_template == "asr":
# speech_tokens remains None
for i in range(batch_size):
message_list_item = [
{
"role": "user",
"content": f"Transcribe the following audio into text:\\n\\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": answers[i]},
]
messages.append(message_list_item)
else:
raise ValueError(f"Unknown prompt template: {prompt_template}")
return messages, speech_tokens
def preprocess(
messages,
@ -459,6 +426,19 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
def process_batch_text_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
transcript = batch["supervisions"]["cut"][i].custom["text"]
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
messages.append(message)
return messages
def preprocess_teacher(
messages,
@ -551,20 +531,9 @@ def compute_loss(
feature = feature.transpose(1, 2) # (N, C, T)
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
if params.dataset_format == "slam_omni":
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
if params.loss_type == "kl_div":
messages_text = process_batch_text_vocalnet(batch)
elif params.dataset_format == "speech_continuation":
messages = process_batch_speech_continuation(batch)
if params.loss_type == "kl_div":
messages_text = process_batch_text_continuation(batch)
elif params.dataset_format == "asr":
messages = process_batch_asr(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
batch, params.prompt_template, params.enable_speech_output
)
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
@ -581,6 +550,8 @@ def compute_loss(
labels=target_ids.to(device),
)
elif params.loss_type == "kl_div":
assert params.prompt_template == "speech_continuation"
messages_text = process_batch_text_continuation(batch)
(
teacher_input_ids,
teacher_attention_mask,
@ -598,6 +569,7 @@ def compute_loss(
else:
raise ValueError(f"Unknown loss type: {params.loss_type}")
else:
assert params.loss_type == "ce"
(
text_loss,
acc,
@ -918,13 +890,13 @@ def run(rank, world_size, args):
else:
attn_implementation = "eager"
torch_dtype = torch.float16
if params.dataset_format == "slam_omni":
codec_vocab_size = 4096 + 4
elif params.dataset_format == "vocalnet":
if params.speech_tokenizer_type == "cosyvoice2":
codec_vocab_size = 6561 + 4
elif params.speech_tokenizer_type == "cosyvoice1":
codec_vocab_size = 4096 + 4
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
# TODO: modify above vocab size or supress_tokens when decoding
raise ValueError(f"Unknown speech tokenizer type: {params.speech_tokenizer_type}")
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
@ -1029,24 +1001,23 @@ def run(rank, world_size, args):
return False
return True
if params.dataset_format == "slam_omni":
train_cuts = data_module.train_cuts()
valid_cuts = data_module.dev_cuts()
elif params.dataset_format == "vocalnet":
if params.dataset == "slam_omni_belle":
train_cuts = data_module.train_cuts_belle()
valid_cuts = data_module.dev_cuts_belle()
elif params.dataset == "vocalnet_ultrachat_voiceassistant":
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
elif params.dataset_format == "speech_continuation" or params.dataset_format == "asr":
if params.dataset == "multi_en":
train_cuts = data_module.train_cuts_ultravox()
elif params.dataset == "librispeech":
train_cuts = data_module.train_cuts_librispeech()
elif params.dataset == "gigaspeech":
train_cuts = data_module.train_cuts_gigaspeech()
else:
raise ValueError(f"Unknown dataset: {params.dataset}")
elif params.dataset == "ultravox_multi_en":
train_cuts = data_module.train_cuts_ultravox()
valid_cuts = data_module.valid_cuts_ultravox()
elif params.dataset == "librispeech":
train_cuts = data_module.train_cuts_librispeech()
valid_cuts = data_module.valid_cuts_ultravox()
elif params.dataset == "gigaspeech":
train_cuts = data_module.train_cuts_gigaspeech()
valid_cuts = data_module.valid_cuts_ultravox()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
raise ValueError(f"Unknown dataset: {params.dataset}")
train_cuts = train_cuts.filter(remove_short_and_long_utt)
valid_cuts = valid_cuts.filter(remove_short_and_long_utt)

View File

@ -8,11 +8,10 @@ import random
import re
import subprocess
from collections import defaultdict
# from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
# from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import torch