From 448a4eeea79281753efc7406f3130f5d498fa19b Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 29 Apr 2025 07:33:34 +0000
Subject: [PATCH] update hf dataset loading into lhotse

---
 .../local/compute_whisper_fbank.py         |  63 ++++----
 egs/speech_llm/SPEECH2SPEECH/prepare.sh    |  27 ++--
 .../SPEECH2SPEECH/slam_omni/data_module.py |  53 +++++--
 .../SPEECH2SPEECH/slam_omni/decode.py      | 134 ++++++++++++++----
 .../SPEECH2SPEECH/slam_omni/train.py       |  76 +++++-----
 5 files changed, 229 insertions(+), 124 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
index 1c3a3d1e0..b01a35c7d 100755
--- a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
+++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
@@ -2,6 +2,7 @@
 # Copyright 2021 Johns Hopkins University (Piotr Żelasko)
 # Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
 # Copyright 2023 Xiaomi Corp. (Zengrui Jin)
+# Copyright 2025 Nvidia (Yuekai Zhang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -23,12 +24,7 @@ from pathlib import Path
 
 import torch
 from datasets import load_dataset
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    WhisperFbank,
-    WhisperFbankConfig,
-)
+from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
 
 from icefall.utils import str2bool
 
@@ -93,7 +89,12 @@ def get_parser():
         default="answer",
         help="The key in the Huggingface dataset containing the text data",
     )
-
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        default="belle",
+        help="""The dataset prefix to use when saving the features""",
+    )
     return parser
 
 
@@ -114,27 +115,28 @@ def compute_fbank(args):
             WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
         )
     else:
-        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+        raise NotImplementedError("Only WhisperFbank is implemented.")
 
     logging.info(f"device: {device}")
 
-    start = 0
-    stop = 1601
+    dataset = load_dataset(
+        args.huggingface_dataset_path_or_name, streaming=True, split="train"
+    )
+    num_shards = dataset.num_shards
     num_digits = 5
-    for i in range(start, stop):
+    for i in range(num_shards):
+        shard = dataset.shard(num_shards, i)
+        shard = shard.take(10)  # for testing
+        logging.info(
+            f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
+        )
+
         idx = f"{i}".zfill(num_digits)
-        # dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-        parquet_files = [
-            f"data/train-{idx}-of-01601.parquet",
-        ]
-        parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
-        file_name = parquet_files[0]
-        logging.info(f"Loading dataset from {file_name}")
-        dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-        cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
+        cut_set = CutSet.from_huggingface_dataset(
+            shard, audio_key=args.audio_key, text_key=args.text_key
+        )
 
-        logging.info("Splitting cuts into smaller chunks")
         cut_set = cut_set.trim_to_supervisions(
             keep_overlapping=False, min_duration=None
         )
 
@@ -153,22 +155,13 @@ def compute_fbank(args):
             storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
-        cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
+        cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
        logging.info(f"Saving to {cuts_path}")
-        # cut_set.to_file(cuts_path)
-        remove_recording_item(cut_set, cuts_path)
+        # see https://github.com/lhotse-speech/lhotse/issues/1125
+        cut_set.drop_recordings().to_file(cuts_path)
+        if i > 1:
+            break
 
-def remove_recording_item(
-    cuts,
-    output_cuts,
-):
-    """
-    don't store recording item
-    """
-    with CutSet.open_writer(output_cuts) as writer:
-        for cut in cuts:
-            cut.recording.sources = None
-            writer.write(cut)
 
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 1b49daa65..47320ab66 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -20,10 +20,10 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
-  pip uninstall lhotse
-  cd /workspace/slam/lhotse
-  git config --global --add safe.directory /workspace/slam/lhotse
-  pip install -e '.[dev]'
+  #pip uninstall lhotse
+  #cd /workspace/slam/lhotse
+  #git config --global --add safe.directory /workspace/slam/lhotse
+  #pip install -e '.[dev]'
   cd -
   pip install -r slam_omni/requirements.txt
 fi
@@ -31,7 +31,12 @@ fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
-  python3 local/compute_whisper_fbank.py
+  python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank_test \
+    --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
+    --audio-key question_audio --text-key answer \
+    --prefix belle
 fi
 
 
@@ -42,7 +47,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
   # # remove cust_belle_00000.jsonl.gz from pieces
   # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
-  echo $pieces | wc 
+  echo $pieces | wc
   lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
   cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
 fi
@@ -52,16 +57,18 @@ fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   exp_dir=./slam_omni/exp_speech2speech_rerun
+  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 997 --avg 1 \
+    --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech_rerun \
+    --method e2e-epoch10_speech2speech_rerun \
     --enable-speech-output True \
+    --token2wav-path /workspace/CosyVoice-300M-SFT \
     --use-lora True # --on-the-fly-feats True
 fi
@@ -120,7 +127,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     --use-flash-attn True \
     --enable-speech-output True \
     --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
-    --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share 
+    --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
 fi
 
 
@@ -133,4 +140,4 @@
     wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
     tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
   fi
-fi
\ No newline at end of file
+fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py
index 11e3bc779..7cab52f73 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py
@@ -24,7 +24,14 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest, load_manifest_lazy
+from datasets import load_dataset
+from lhotse import (
+    CutSet,
+    WhisperFbank,
+    WhisperFbankConfig,
+    load_manifest,
+    load_manifest_lazy,
+)
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -38,11 +45,11 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
+from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from datasets import load_dataset
 
 from icefall.utils import str2bool
-from speech_dataset import K2SpeechRecognitionDataset
+
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -310,7 +317,9 @@ class AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
+                input_strategy=OnTheFlyFeatures(
+                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -365,7 +374,9 @@ class AsrDataModule:
 
         logging.info("About to create dev dataset")
         validate = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -390,7 +401,9 @@ class AsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -419,16 +432,27 @@ class AsrDataModule:
             parquet_files = [
                 f"data/train-{idx}-of-01601.parquet",
             ]
-            parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
+            parquet_files = [
+                f"{self.args.huggingface_dataset_path_or_name}/{f}"
+                for f in parquet_files
+            ]
             file_name = parquet_files[0]
             logging.info(f"Loading dataset from {file_name}")
-            dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-            cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
+            dataset = load_dataset(
+                "parquet", data_files=parquet_files, streaming=True, split="train"
+            )
+            cut_set = CutSet.from_huggingface_dataset(
+                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
+            )
             if self.args.resample_to_16kHz:
                 cut_set = cut_set.resample(16000)
-            return {'test':cut_set}
+            return {"test": cut_set}
         else:
-            return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
+            return {
+                "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+            }
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
@@ -436,10 +460,11 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             pass
         else:
-            return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
-
+            return load_manifest_lazy(
+                self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+            )
 
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
\ No newline at end of file
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index 5cda487e3..acd882d18 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -23,7 +23,7 @@ Usage:
 
 pip install huggingface_hub['cli']
 mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B 
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
 
 # For aishell fine-tuned whisper model
 huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
@@ -48,32 +48,74 @@
 
 import argparse
 import logging
+import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import soundfile as sf
 import torch
 import torch.nn as nn
 import transformers
 import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-
 from peft import LoraConfig, get_peft_model
-from train import DEFAULT_SPEECH_TOKEN
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from train import add_model_arguments
+
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
     write_error_stats,
-    average_checkpoints,
 )
 
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+    """
+    Generate an audio waveform from CosyVoice speech tokens with the codec decoder.
+
+    Args:
+        audio_tokens (list): List of audio tokens to be processed.
+        codec_decoder: Codec decoder for generating audio.
+
+    Returns:
+        torch.Tensor: Generated audio waveform.
+    """
+    flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+    prompt_speech_feat = torch.zeros(1, 0, 80)
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+        prompt_token_len=torch.tensor(
+            [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+        prompt_feat_len=torch.tensor(
+            [prompt_speech_feat.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        embedding=flow_embedding.to(codec_decoder.model.device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+
+    return audio_hat
+
+
 def get_model(params, device):
     """Load and prepare the speech-to-speech model."""
     if params.remove_whisper_encoder_input_length_restriction:
@@ -136,7 +178,7 @@ def get_model(params, device):
     # Determine attn_implementation and torch_dtype based on use_flash_attn
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
-        torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
+        torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
     else:
         attn_implementation = "eager"
         torch_dtype = torch.float16
@@ -162,7 +204,7 @@ def get_model(params, device):
     codec_lm = AutoModelForCausalLM.from_config(
         config=config,
         attn_implementation=attn_implementation,
-        torch_dtype=torch_dtype
+        torch_dtype=torch_dtype,
     )
     # cosyvoice2_token_size = 6561
     codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -197,7 +239,7 @@ def get_model(params, device):
         llm,
         encoder_projector,
         codec_lm,
-        codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )
 
     if params.avg > 1:
@@ -325,6 +367,12 @@ def get_parser():
         help="The experiment dir",
     )
 
+    parser.add_argument(
+        "--token2wav-path",
+        type=str,
+        default="/workspace/CosyVoice-300M-SFT",
+        help="The path to the token2wav model",
+    )
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
@@ -350,6 +398,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
     batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -431,26 +480,32 @@ def decode_one_batch(
     #         {"role": "assistant", "content": ""},
     #     ]
     # ] * len(feature)
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-    history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history]
-    last_questions = [question.split(': ')[-1].strip() for question in questions_with_history]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
+    history_contexts = [
+        question.rsplit(":", 1)[0].strip() for question in questions_with_history
+    ]
+    last_questions = [
+        question.split(": ")[-1].strip() for question in questions_with_history
+    ]
     messages = []
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
             # {"role": "user", "content": f"{last_questions[i]}"},
-            {"role": "assistant", "content": ""}
+            {"role": "assistant", "content": ""},
         ]
         print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
@@ -461,16 +516,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id in cut_ids:
-            speech_token_file_name = (
-                params.log_dir / f"{cut_id}.txt"
-            )
-            with open(speech_token_file_name, 'w') as f:
-                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-                #torchaudio.save(save_path, speech_output.cpu(), 16000)
-                # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-                save_str = " ".join([str(i) for i in generated_speech_output])
-                f.write(f"{cut_id}|{save_str}\n")
+        generated_speech_output = [
+            generated_speech_output
+        ]  # WAR: only support batch = 1 for now
+        for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
+            speech_file_name = params.log_dir / f"{cut_id}.wav"
+            audio_tokens = [token for token in audio_tokens if token < 4096]
+            audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+            audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+            sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+            # with open(speech_token_file_name, 'w') as f:
+            #     # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            #     #torchaudio.save(save_path, speech_output.cpu(), 16000)
+            #     # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+            #     save_str = " ".join([str(i) for i in generated_speech_output])
+            #     f.write(f"{cut_id}|{save_str}\n")
 
     else:
         generated_ids = model.decode(
@@ -486,6 +546,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -548,14 +609,23 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         answers = batch["supervisions"]["text"]
-        questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-        answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-        texts = [question.split(': ')[-1].strip() for question in questions_with_history]
+        questions_with_history = [
+            cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        ]
+        answer_cosyvoice_speech_token = [
+            cut.custom["answer_cosyvoice_speech_token"]
+            for cut in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            question.split(": ")[-1].strip()
+            for question in questions_with_history
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             batch=batch,
             tokenizer=tokenizer,
         )
@@ -643,9 +713,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
     params.log_dir.mkdir(parents=True, exist_ok=True)
-    setup_logger(
-        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
-    )
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
 
     logging.info("Decoding started")
     logging.info(params)
@@ -657,6 +725,9 @@ def main():
     logging.info(f"device: {device}")
 
     model, tokenizer = get_model(params, device)
+    token2wav_model = CosyVoice(
+        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+    )
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -697,6 +768,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             tokenizer=tokenizer,
         )
 
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 3b971dd89..1438a2624 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -66,8 +66,9 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
+
 # from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import (
@@ -146,6 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to enable speech codec output.",
     )
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -332,9 +334,7 @@ def compute_loss(
     # remove too long text
     # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
-        logging.warning(
-            f"Remove too long text, {messages} "
-        )
+        logging.warning(f"Remove too long text, {messages} ")
     max_len_texts = max([len(text) for text in texts])
     if tokenizer.padding_side == "right":
         texts = [
@@ -354,10 +354,10 @@ def compute_loss(
         # first get the indices of the tokens
         mask_prompt = True
         if mask_prompt:
-            default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-            mask_indices = torch.where(
-                input_ids == default_speech_token_id
-            )
+            default_speech_token_id = tokenizer.convert_tokens_to_ids(
+                DEFAULT_SPEECH_TOKEN
+            )
+            mask_indices = torch.where(input_ids == default_speech_token_id)
             for i in range(mask_indices[0].size(0)):
                 row = mask_indices[0][i]
                 col = mask_indices[1][i]
@@ -382,30 +382,39 @@ def compute_loss(
         batch_idx_train = params.batch_idx_train
 
         answers = batch["supervisions"]["text"]
-        questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+        questions_with_history = [
+            cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        ]
         chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
-        answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-        last_questions = [question.split(': ')[-1].strip() for question in questions_with_history]
-        history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history]
-        # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。: 告诉我如何烹饪鸡肉
+        answer_cosyvoice_speech_token = [
+            cut.custom["answer_cosyvoice_speech_token"]
+            for cut in batch["supervisions"]["cut"]
+        ]
+        last_questions = [
+            question.split(": ")[-1].strip() for question in questions_with_history
+        ]
+        history_contexts = [
+            question.rsplit(":", 1)[0].strip() for question in questions_with_history
+        ]
+        # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。: 告诉我如何烹饪鸡肉
         # : 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
 
         messages = []
         for i, total_round in enumerate(chat_rounds):
             message = []
             if total_round > 1:
-                history_question_answer = history_contexts[i].split('USER:')
+                history_question_answer = history_contexts[i].split("USER:")
                 history_question_answer = [item for item in history_question_answer if item]
                 for j in range(total_round - 1):
                     # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                    question_answer = history_question_answer[j].split('ASSISTANT:')
+                    question_answer = history_question_answer[j].split("ASSISTANT:")
                     message += [
                         {"role": "user", "content": question_answer[0].strip()},
-                        {"role": "assistant", "content": question_answer[1].strip()}
+                        {"role": "assistant", "content": question_answer[1].strip()},
                     ]
             message += [
                 {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-                {"role": "assistant", "content": answers[i]}
+                {"role": "assistant", "content": answers[i]},
             ]
             messages.append(message)
 
@@ -423,7 +432,13 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
+        (
+            text_loss,
+            acc,
+            codec_loss,
+            codec_acc,
+            codec_topk_acc,
+        ) = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -445,12 +460,8 @@ def compute_loss(
         acc * info["frames"]
     )  # WAR: to avoid normalization by the number of frames
     if params.enable_speech_output:
-        info["codec_acc"] = (
-            codec_acc * info["frames"]
-        )
-        info["codec_topk_acc"] = (
-            codec_topk_acc * info["frames"]
-        )
+        info["codec_acc"] = codec_acc * info["frames"]
+        info["codec_topk_acc"] = codec_topk_acc * info["frames"]
         info["codec_loss"] = codec_loss.detach().cpu().item()
         info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -469,7 +480,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.amp.autocast('cuda', enabled=params.use_fp16):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -584,7 +595,7 @@ def train_one_epoch(
                     f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                 )
         try:
-            with torch.amp.autocast('cuda', enabled=params.use_fp16):
+            with torch.amp.autocast("cuda", enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -722,7 +733,6 @@ def run(rank, world_size, args):
 
     # model.resize_token_embeddings(len(tokenizer))
     # model.vocab_size = len(tokenizer)
-
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
@@ -736,12 +746,11 @@ def run(rank, world_size, args):
             param.requires_grad = False
         encoder_projector.eval()
 
-
     if params.enable_speech_output:
         # Determine attn_implementation and torch_dtype based on use_flash_attn
         if params.use_flash_attn:
             attn_implementation = "flash_attention_2"
-            torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
         else:
             attn_implementation = "eager"
             torch_dtype = torch.float16
@@ -766,9 +775,9 @@ def run(rank, world_size, args):
         # Pass attn_implementation and torch_dtype to the constructor
         # Use AutoModelForCausalLM.from_config for more generality
        codec_lm = AutoModelForCausalLM.from_config(
-            config=config, 
-            attn_implementation=attn_implementation, 
-            torch_dtype=torch_dtype
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
         )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -803,7 +812,7 @@ def run(rank, world_size, args):
             llm,
             encoder_projector,
             codec_lm,
-            codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+            codec_lm_padding_side="left" if params.use_flash_attn else "right",
         )
 
     if params.pretrained_model_path:
@@ -851,12 +860,11 @@ def run(rank, world_size, args):
             codec_len = len(c.custom["answer_cosyvoice_speech_token"])
             if codec_len > 2200:
                 logging.warning(
-                    f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
+                    f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
                 )
                 return False
         return True
 
-
     train_cuts = data_module.train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)