update hf dataset loading into lhotse

root 2025-04-29 07:33:34 +00:00
parent d742043e75
commit 448a4eeea7
5 changed files with 229 additions and 124 deletions

View File

@@ -2,6 +2,7 @@
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
# Copyright 2025 Nvidia (Yuekai Zhang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -23,12 +24,7 @@ from pathlib import Path
import torch
from datasets import load_dataset
from lhotse import (
CutSet,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
from icefall.utils import str2bool
@@ -93,7 +89,12 @@ def get_parser():
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
parser.add_argument(
"--prefix",
type=str,
default="belle",
help="""The dataset prefix to use when saving the features""",
)
return parser
@@ -114,27 +115,28 @@ def compute_fbank(args):
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
raise NotImplementedError("Only WhisperFbank is implemented.")
logging.info(f"device: {device}")
start = 0
stop = 1601
dataset = load_dataset(
args.huggingface_dataset_path_or_name, streaming=True, split="train"
)
num_shards = dataset.num_shards
num_digits = 5
for i in range(start, stop):
for i in range(num_shards):
shard = dataset.shard(num_shards, i)
shard = shard.take(10) # for testing
logging.info(
f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
)
idx = f"{i}".zfill(num_digits)
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
parquet_files = [
f"data/train-{idx}-of-01601.parquet",
]
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
cut_set = CutSet.from_huggingface_dataset(
shard, audio_key=args.audio_key, text_key=args.text_key
)
logging.info("Splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
@@ -153,22 +155,13 @@ def compute_fbank(args):
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# cut_set.to_file(cuts_path)
remove_recording_item(cut_set, cuts_path)
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.drop_recordings().to_file(cuts_path)
if i > 1:
break
def remove_recording_item(
cuts,
output_cuts,
):
"""
don't store recording item
"""
with CutSet.open_writer(output_cuts) as writer:
for cut in cuts:
cut.recording.sources = None
writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
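
Taken together, the edits to this file replace the hardcoded data/train-XXXXX-of-01601.parquet file list with sharding of a single streaming Hugging Face dataset. Below is a condensed sketch of the new per-shard flow; it is not a verbatim excerpt, the CLI argument values are illustrative, and the feature-writing call is assumed to be lhotse's compute_and_store_features_batch (only its kwargs are visible in the hunk above).

import torch
from datasets import load_dataset
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig


def compute_fbank_sketch(
    dataset_path="/workspace/Belle_1.4M-SLAM-Omni",  # --huggingface-dataset-path-or-name
    out_dir="data/fbank_test",  # --out-dir
    prefix="belle",  # --prefix
    audio_key="question_audio",  # --audio-key
    text_key="answer",  # --text-key
    num_mel_bins=80,  # --num-mel-bins
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = WhisperFbank(
        WhisperFbankConfig(num_filters=num_mel_bins, device=device)
    )

    # One streaming dataset, split into its native shards instead of a
    # hardcoded list of parquet file names.
    dataset = load_dataset(dataset_path, streaming=True, split="train")
    for i in range(dataset.num_shards):
        shard = dataset.shard(dataset.num_shards, i)
        idx = f"{i}".zfill(5)
        cut_set = CutSet.from_huggingface_dataset(
            shard, audio_key=audio_key, text_key=text_key
        ).resample(16000)
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )
        cut_set = cut_set.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=f"{out_dir}/{prefix}_feats_{idx}",  # illustrative naming
            batch_duration=600,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )
        # Keep the manifest small by dropping recording sources,
        # see https://github.com/lhotse-speech/lhotse/issues/1125.
        cut_set.drop_recordings().to_file(f"{out_dir}/{prefix}_cuts.{idx}.jsonl.gz")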

View File

@@ -20,10 +20,10 @@ log() {
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: "
pip uninstall lhotse
cd /workspace/slam/lhotse
git config --global --add safe.directory /workspace/slam/lhotse
pip install -e '.[dev]'
#pip uninstall lhotse
#cd /workspace/slam/lhotse
#git config --global --add safe.directory /workspace/slam/lhotse
#pip install -e '.[dev]'
cd -
pip install -r slam_omni/requirements.txt
fi
@@ -31,7 +31,12 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
python3 local/compute_whisper_fbank.py
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_test \
--huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix belle
fi
@@ -52,16 +57,18 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: "
exp_dir=./slam_omni/exp_speech2speech_rerun
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
python3 ./slam_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method small_test_speech2speech_rerun \
--method e2e-epoch10_speech2speech_rerun \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice-300M-SFT \
--use-lora True # --on-the-fly-feats True
fi

View File

@@ -24,7 +24,14 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest, load_manifest_lazy
from datasets import load_dataset
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
load_manifest,
load_manifest_lazy,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
@@ -38,11 +45,11 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from icefall.utils import str2bool
from speech_dataset import K2SpeechRecognitionDataset
class _SeedWorkers:
def __init__(self, seed: int):
@@ -310,7 +317,9 @@ class AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -365,7 +374,9 @@ class AsrDataModule:
logging.info("About to create dev dataset")
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
@@ -390,7 +401,9 @@ class AsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
@@ -419,16 +432,27 @@ class AsrDataModule:
parquet_files = [
f"data/train-{idx}-of-01601.parquet",
]
parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
parquet_files = [
f"{self.args.huggingface_dataset_path_or_name}/{f}"
for f in parquet_files
]
file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
dataset = load_dataset(
"parquet", data_files=parquet_files, streaming=True, split="train"
)
cut_set = CutSet.from_huggingface_dataset(
dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
)
if self.args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
return {'test':cut_set}
return {"test": cut_set}
else:
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
return {
"test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
}
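
The precomputed branch above reads the manifests written by local/compute_whisper_fbank.py. A small sanity-check sketch for one such manifest (the path is the one hardcoded above; the custom fields are the ones decode.py and train.py read later in this commit):

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
for cut in cuts:
    print(cut.id, cut.duration)
    print(cut.supervisions[0].text)  # --text-key answer
    print(cut.custom["question"])  # dialogue history string
    print(cut.custom["round"])  # number of dialogue rounds
    print(len(cut.custom["answer_cosyvoice_speech_token"]))  # codec target length
    break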
@lru_cache()
def dev_cuts(self) -> CutSet:
@@ -436,8 +460,9 @@ class AsrDataModule:
if self.args.on_the_fly_feats:
pass
else:
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
)
@lru_cache()
def train_cuts(self) -> CutSet:

View File

@@ -48,32 +48,74 @@ python3 ./whisper_llm_zh/decode.py \
import argparse
import logging
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import soundfile as sf
import torch
import torch.nn as nn
import transformers
import whisper
from cosyvoice.cli.cosyvoice import CosyVoice
from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from train import add_model_arguments
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
average_checkpoints,
)
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def audio_decode_cosyvoice(audio_tokens, codec_decoder):
"""
Generate audio from tokens with optional tone and prompt embedding.
Args:
audio_tokens (list): List of audio tokens to be processed.
codec_decoder: Codec decoder for generating audio.
Returns:
torch.Tensor: Generated audio waveform.
"""
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
prompt_speech_feat = torch.zeros(1, 0, 80)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
prompt_token_len=torch.tensor(
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
prompt_feat_len=torch.tensor(
[prompt_speech_feat.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
embedding=flow_embedding.to(codec_decoder.model.device),
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
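
For orientation, this is how the decoding loop further down consumes the helper: codec tokens generated by the speech LM are filtered, batched, vocoded, and written out at 22.05 kHz. A usage sketch with placeholder tokens (not part of the commit); the <4096 cutoff and the sample rate are taken from that loop.

import soundfile as sf
import torch
from cosyvoice.cli.cosyvoice import CosyVoice

token2wav_model = CosyVoice(
    "/workspace/CosyVoice-300M-SFT", load_jit=False, load_trt=False, fp16=False
)
generated_speech_output = [101, 202, 303]  # placeholder codec-LM output
audio_tokens = [t for t in generated_speech_output if t < 4096]  # drop special tokens
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write("example.wav", audio_hat.squeeze(0).cpu().numpy(), 22050)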
def get_model(params, device):
"""Load and prepare the speech-to-speech model."""
if params.remove_whisper_encoder_input_length_restriction:
@@ -136,7 +178,7 @@ def get_model(params, device):
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
@@ -162,7 +204,7 @@ def get_model(params, device):
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype
torch_dtype=torch_dtype,
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -197,7 +239,7 @@ def get_model(params, device):
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.avg > 1:
@@ -325,6 +367,12 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--token2wav-path",
type=str,
default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model",
)
# parser.add_argument(
# "--dataset",
# type=str,
@@ -350,6 +398,7 @@ def decode_one_batch(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
token2wav_model: nn.Module,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -431,26 +480,32 @@ def decode_one_batch(
# {"role": "assistant", "content": ""},
# ]
# ] * len(feature)
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
messages = []
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split('USER:')
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
question_answer = history_question_answer[j].split('ASSISTANT:')
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()}
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "user", "content": f"{last_questions[i]}"},
{"role": "assistant", "content": ""}
{"role": "assistant", "content": ""},
]
print(f"message: {message}, batch_size {len(chat_rounds)}")
messages.append(message)
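
The history parsing above is easier to follow with a concrete, made-up two-turn example; only the string format (plain USER:/ASSISTANT: turns followed by a final <USER>: question) comes from this recipe, and the DEFAULT_SPEECH_TOKEN value below is a stand-in for the constant defined in train.py.

DEFAULT_SPEECH_TOKEN = "<speech>"  # stand-in value; the real constant lives in train.py

question_with_history = (
    "USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长。 "
    "<USER>: 给我列举一些新闻头条。"
)
history_context = question_with_history.rsplit("<USER>:", 1)[0].strip()
last_question = question_with_history.split("<USER>: ")[-1].strip()

message = []
for turn in filter(None, history_context.split("USER:")):
    q, a = turn.split("ASSISTANT:")
    message += [
        {"role": "user", "content": q.strip()},
        {"role": "assistant", "content": a.strip()},
    ]
# The current spoken question enters as the speech placeholder, followed by
# an empty assistant turn for the model to complete.
message += [
    {"role": "user", "content": DEFAULT_SPEECH_TOKEN},
    {"role": "assistant", "content": ""},
]
# message ->
# [{"role": "user", "content": "生成一个关于夏天的诗歌。"},
#  {"role": "assistant", "content": "夏日炎炎,万物生长。"},
#  {"role": "user", "content": "<speech>"},
#  {"role": "assistant", "content": ""}]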
@@ -461,16 +516,21 @@ def decode_one_batch(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
for cut_id in cut_ids:
speech_token_file_name = (
params.log_dir / f"{cut_id}.txt"
)
with open(speech_token_file_name, 'w') as f:
# save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
#torchaudio.save(save_path, speech_output.cpu(), 16000)
# print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
save_str = " ".join([str(i) for i in generated_speech_output])
f.write(f"{cut_id}|{save_str}\n")
generated_speech_output = [
generated_speech_output
] # WAR: only support batch = 1 for now
for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
speech_file_name = params.log_dir / f"{cut_id}.wav"
audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
# with open(speech_token_file_name, 'w') as f:
# # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
# #torchaudio.save(save_path, speech_output.cpu(), 16000)
# # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
# save_str = " ".join([str(i) for i in generated_speech_output])
# f.write(f"{cut_id}|{save_str}\n")
else:
generated_ids = model.decode(
@@ -486,6 +546,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
token2wav_model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -548,14 +609,23 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
answers = batch["supervisions"]["text"]
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
texts = [
question.split("<USER>: ")[-1].strip()
for question in questions_with_history
]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
token2wav_model=token2wav_model,
batch=batch,
tokenizer=tokenizer,
)
@@ -643,9 +713,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
params.log_dir.mkdir(parents=True, exist_ok=True)
setup_logger(
f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
)
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
logging.info("Decoding started")
logging.info(params)
@@ -657,6 +725,9 @@ def main():
logging.info(f"device: {device}")
model, tokenizer = get_model(params, device)
token2wav_model = CosyVoice(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -697,6 +768,7 @@ def main():
dl=test_dl,
params=params,
model=model,
token2wav_model=token2wav_model,
tokenizer=tokenizer,
)

View File

@@ -66,8 +66,9 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
# from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
@@ -146,6 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to enable speech codec output.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -332,9 +334,7 @@ def compute_loss(
# remove too long text
# texts = [ text for text in texts if len(text) < 1024 ]
if len(texts) != len(messages):
logging.warning(
f"Remove too long text, {messages} "
)
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
@@ -354,10 +354,10 @@ def compute_loss(
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(
input_ids == default_speech_token_id
default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
@@ -382,11 +382,20 @@ def compute_loss(
batch_idx_train = params.batch_idx_train
answers = batch["supervisions"]["text"]
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
answer_cosyvoice_speech_token = [
cut.custom["answer_cosyvoice_speech_token"]
for cut in batch["supervisions"]["cut"]
]
last_questions = [
question.split("<USER>: ")[-1].strip() for question in questions_with_history
]
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
@@ -394,18 +403,18 @@ def compute_loss(
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split('USER:')
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
question_answer = history_question_answer[j].split('ASSISTANT:')
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()}
{"role": "assistant", "content": question_answer[1].strip()},
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]}
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
@@ -423,7 +432,13 @@ def compute_loss(
labels=target_ids.to(device),
)
else:
text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
(
text_loss,
acc,
codec_loss,
codec_acc,
codec_topk_acc,
) = model.forward_with_speech_output(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
@@ -445,12 +460,8 @@ def compute_loss(
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
if params.enable_speech_output:
info["codec_acc"] = (
codec_acc * info["frames"]
)
info["codec_topk_acc"] = (
codec_topk_acc * info["frames"]
)
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
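
A note on the frames scaling above: icefall's MetricsTracker divides every tracked value by the frame count when it is logged, so pre-multiplying the accuracies by info["frames"] makes the printed numbers come out as plain accuracies again (hence the "WAR" comment). A tiny sketch of that arithmetic:

frames = 1234
codec_acc = 0.87
info = {"frames": frames, "codec_acc": codec_acc * frames}
printed = info["codec_acc"] / info["frames"]  # what the tracker reports
assert abs(printed - codec_acc) < 1e-9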
@@ -469,7 +480,7 @@ def compute_validation_loss(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.amp.autocast('cuda', enabled=params.use_fp16):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
@@ -584,7 +595,7 @@ def train_one_epoch(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.amp.autocast('cuda', enabled=params.use_fp16):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
@@ -722,7 +733,6 @@ def run(rank, world_size, args):
# model.resize_token_embeddings(len(tokenizer))
# model.vocab_size = len(tokenizer)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
@@ -736,12 +746,11 @@ def run(rank, world_size, args):
param.requires_grad = False
encoder_projector.eval()
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
@@ -768,7 +777,7 @@ def run(rank, world_size, args):
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype
torch_dtype=torch_dtype,
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -803,7 +812,7 @@ def run(rank, world_size, args):
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.pretrained_model_path:
@@ -851,12 +860,11 @@ def run(rank, world_size, args):
codec_len = len(c.custom["answer_cosyvoice_speech_token"])
if codec_len > 2200:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
)
return False
return True
train_cuts = data_module.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
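
For context, only the codec-length check of remove_short_and_long_utt is visible in this hunk. Below is a self-contained sketch of such a filter; the 1 s / 30 s duration bounds are illustrative assumptions, and only the 2200-token cutoff comes from this commit.

import logging

from lhotse.cut import Cut


def remove_short_and_long_utt_sketch(c: Cut) -> bool:
    # Drop cuts whose audio is unusually short or long (assumed bounds).
    if c.duration < 1.0 or c.duration > 30.0:
        return False
    # Drop cuts whose speech-codec target is too long for the codec LM.
    codec_len = len(c.custom["answer_cosyvoice_speech_token"])
    if codec_len > 2200:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. "
            f"Duration: {c.duration}, length: {codec_len}"
        )
        return False
    return True


# train_cuts = train_cuts.filter(remove_short_and_long_utt_sketch)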