Update Hugging Face dataset loading into lhotse

root 2025-04-29 07:33:34 +00:00
parent d742043e75
commit 448a4eeea7
5 changed files with 229 additions and 124 deletions

View File

@@ -2,6 +2,7 @@
 # Copyright 2021 Johns Hopkins University (Piotr Żelasko)
 # Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
 # Copyright 2023 Xiaomi Corp. (Zengrui Jin)
+# Copyright 2025 Nvidia (Yuekai Zhang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -23,12 +24,7 @@ from pathlib import Path
 import torch
 from datasets import load_dataset
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    WhisperFbank,
-    WhisperFbankConfig,
-)
+from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig

 from icefall.utils import str2bool
@@ -93,7 +89,12 @@ def get_parser():
         default="answer",
         help="The key in the Huggingface dataset containing the text data",
     )
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        default="belle",
+        help="""The dataset prefix to use when saving the features""",
+    )
     return parser
@@ -114,27 +115,28 @@ def compute_fbank(args):
             WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
         )
     else:
-        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+        raise NotImplementedError("Only WhisperFbank is implemented.")
     logging.info(f"device: {device}")
-    start = 0
-    stop = 1601
+    dataset = load_dataset(
+        args.huggingface_dataset_path_or_name, streaming=True, split="train"
+    )
+    num_shards = dataset.num_shards
     num_digits = 5
-    for i in range(start, stop):
+    for i in range(num_shards):
+        shard = dataset.shard(num_shards, i)
+        shard = shard.take(10)  # for testing
+        logging.info(
+            f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
+        )
         idx = f"{i}".zfill(num_digits)
-        # dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-        parquet_files = [
-            f"data/train-{idx}-of-01601.parquet",
-        ]
-        parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
-        file_name = parquet_files[0]
-        logging.info(f"Loading dataset from {file_name}")
-        dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-        cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
+        cut_set = CutSet.from_huggingface_dataset(
+            shard, audio_key=args.audio_key, text_key=args.text_key
+        )
+        logging.info("Splitting cuts into smaller chunks")
         cut_set = cut_set.trim_to_supervisions(
             keep_overlapping=False, min_duration=None
         )
@@ -153,22 +155,13 @@ def compute_fbank(args):
             storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
-        cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
+        cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
        logging.info(f"Saving to {cuts_path}")
-        # cut_set.to_file(cuts_path)
-        remove_recording_item(cut_set, cuts_path)
-        if i > 1:
-            break
-
-
-def remove_recording_item(
-    cuts,
-    output_cuts,
-):
-    """
-    don't store recording item
-    """
-    with CutSet.open_writer(output_cuts) as writer:
-        for cut in cuts:
-            cut.recording.sources = None
-            writer.write(cut)
+        # see https://github.com/lhotse-speech/lhotse/issues/1125
+        cut_set.drop_recordings().to_file(cuts_path)


 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
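
For reference, the streaming flow that this file now implements can be sketched end to end as below. This is a minimal sketch rather than the recipe's exact code: the dataset path, output locations, and batch settings are illustrative, and it assumes a datasets version whose streaming datasets expose num_shards and shard(), as used in the diff.

# Minimal sketch of the streaming HF -> lhotse feature-extraction flow (illustrative paths/values).
from datasets import load_dataset
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig

dataset = load_dataset("/workspace/Belle_1.4M-SLAM-Omni", streaming=True, split="train")
extractor = WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))

for i in range(dataset.num_shards):
    shard = dataset.shard(dataset.num_shards, i)  # process one parquet shard at a time
    cut_set = CutSet.from_huggingface_dataset(
        shard, audio_key="question_audio", text_key="answer"
    )
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False).resample(16000)
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"data/fbank_test/belle_feats_{i:05d}",
        batch_duration=500,
        num_workers=4,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )
    # Drop the recordings before writing the manifest
    # (see https://github.com/lhotse-speech/lhotse/issues/1125, referenced above).
    cut_set.drop_recordings().to_file(f"data/fbank_test/belle_cuts.{i:05d}.jsonl.gz")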

View File

@@ -20,10 +20,10 @@ log() {
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
-  pip uninstall lhotse
-  cd /workspace/slam/lhotse
-  git config --global --add safe.directory /workspace/slam/lhotse
-  pip install -e '.[dev]'
+  #pip uninstall lhotse
+  #cd /workspace/slam/lhotse
+  #git config --global --add safe.directory /workspace/slam/lhotse
+  #pip install -e '.[dev]'
   cd -
   pip install -r slam_omni/requirements.txt
 fi
@@ -31,7 +31,12 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
-  python3 local/compute_whisper_fbank.py
+  python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank_test \
+    --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
+    --audio-key question_audio --text-key answer \
+    --prefix belle
 fi
@@ -42,7 +47,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
   # # remove cust_belle_00000.jsonl.gz from pieces
   # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
   echo $pieces | wc
   lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
   cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
 fi
@@ -52,16 +57,18 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   exp_dir=./slam_omni/exp_speech2speech_rerun
+  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 997 --avg 1 \
+    --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech_rerun \
+    --method e2e-epoch10_speech2speech_rerun \
     --enable-speech-output True \
+    --token2wav-path /workspace/CosyVoice-300M-SFT \
     --use-lora True # --on-the-fly-feats True
 fi
@@ -120,7 +127,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     --use-flash-attn True \
     --enable-speech-output True \
     --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
     --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
 fi
@@ -133,4 +140,4 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
   tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
 fi
 fi

View File

@@ -24,7 +24,14 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest, load_manifest_lazy
+from datasets import load_dataset
+from lhotse import (
+    CutSet,
+    WhisperFbank,
+    WhisperFbankConfig,
+    load_manifest,
+    load_manifest_lazy,
+)
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -38,11 +45,11 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
+from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from datasets import load_dataset

 from icefall.utils import str2bool
-from speech_dataset import K2SpeechRecognitionDataset

 class _SeedWorkers:
@@ -310,7 +317,9 @@ class AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
+                input_strategy=OnTheFlyFeatures(
+                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -365,7 +374,9 @@ class AsrDataModule:
         logging.info("About to create dev dataset")
         validate = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -390,7 +401,9 @@ class AsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -419,16 +432,27 @@ class AsrDataModule:
             parquet_files = [
                 f"data/train-{idx}-of-01601.parquet",
             ]
-            parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
+            parquet_files = [
+                f"{self.args.huggingface_dataset_path_or_name}/{f}"
+                for f in parquet_files
+            ]
             file_name = parquet_files[0]
             logging.info(f"Loading dataset from {file_name}")
-            dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-            cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
+            dataset = load_dataset(
+                "parquet", data_files=parquet_files, streaming=True, split="train"
+            )
+            cut_set = CutSet.from_huggingface_dataset(
+                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
+            )
             if self.args.resample_to_16kHz:
                 cut_set = cut_set.resample(16000)
-            return {'test':cut_set}
+            return {"test": cut_set}
         else:
-            return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
+            return {
+                "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+            }

     @lru_cache()
     def dev_cuts(self) -> CutSet:
@@ -436,10 +460,11 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             pass
         else:
-            return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+            )

     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
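
To make the on-the-fly feature path above concrete, here is a minimal, self-contained sketch of a test dataloader built the same way. It uses lhotse's stock K2SpeechRecognitionDataset and SimpleCutSampler instead of the recipe's local speech_dataset module, and the manifest path and max_duration are illustrative.

# Minimal sketch: a test dataloader with on-the-fly Whisper fbank, as wired above.
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest_lazy
from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

cuts: CutSet = load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
test = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(
        WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
    ),
    return_cuts=True,
)
sampler = SimpleCutSampler(cuts, max_duration=50.0, shuffle=False)
dl = DataLoader(test, sampler=sampler, batch_size=None, num_workers=2)

for batch in dl:
    feats = batch["inputs"]  # (N, T, 80) Whisper fbank computed on the fly
    batch_cuts = batch["supervisions"]["cut"]  # available because return_cuts=True
    break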

View File

@@ -23,7 +23,7 @@ Usage:
 pip install huggingface_hub['cli']
 mkdir -p models/whisper models/qwen models/checkpoint
 huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
 # For aishell fine-tuned whisper model
 huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
@@ -48,32 +48,74 @@ python3 ./whisper_llm_zh/decode.py \
 import argparse
 import logging
+import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

+import soundfile as sf
 import torch
 import torch.nn as nn
 import transformers
 import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model
-from train import DEFAULT_SPEECH_TOKEN
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from train import add_model_arguments

 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
     write_error_stats,
-    average_checkpoints,
 )

+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+    """
+    Generate audio from tokens with optional tone and prompt embedding.
+
+    Args:
+        audio_tokens (list): List of audio tokens to be processed.
+        codec_decoder: Codec decoder for generating audio.
+
+    Returns:
+        torch.Tensor: Generated audio waveform.
+    """
+    flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+    prompt_speech_feat = torch.zeros(1, 0, 80)
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+        prompt_token_len=torch.tensor(
+            [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+        prompt_feat_len=torch.tensor(
+            [prompt_speech_feat.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        embedding=flow_embedding.to(codec_decoder.model.device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+    return audio_hat
+
+
 def get_model(params, device):
     """Load and prepare the speech-to-speech model."""
     if params.remove_whisper_encoder_input_length_restriction:
@@ -136,7 +178,7 @@ def get_model(params, device):
     # Determine attn_implementation and torch_dtype based on use_flash_attn
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
     else:
         attn_implementation = "eager"
         torch_dtype = torch.float16
@@ -162,7 +204,7 @@ def get_model(params, device):
     codec_lm = AutoModelForCausalLM.from_config(
         config=config,
         attn_implementation=attn_implementation,
-        torch_dtype=torch_dtype
+        torch_dtype=torch_dtype,
     )
     # cosyvoice2_token_size = 6561
     codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -197,7 +239,7 @@ def get_model(params, device):
         llm,
         encoder_projector,
         codec_lm,
-        codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )

     if params.avg > 1:
@@ -325,6 +367,12 @@ def get_parser():
         help="The experiment dir",
     )
+    parser.add_argument(
+        "--token2wav-path",
+        type=str,
+        default="/workspace/CosyVoice-300M-SFT",
+        help="The path to the token2wav model",
+    )
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
@@ -350,6 +398,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
     batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -431,26 +480,32 @@ def decode_one_batch(
     #         {"role": "assistant", "content": ""},
     #     ]
     # ] * len(feature)
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
     messages = []
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
            # {"role": "user", "content": f"{last_questions[i]}"},
-            {"role": "assistant", "content": ""}
+            {"role": "assistant", "content": ""},
         ]
         print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
@@ -461,16 +516,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id in cut_ids:
-            speech_token_file_name = (
-                params.log_dir / f"{cut_id}.txt"
-            )
-            with open(speech_token_file_name, 'w') as f:
-                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-                #torchaudio.save(save_path, speech_output.cpu(), 16000)
-                # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-                save_str = " ".join([str(i) for i in generated_speech_output])
-                f.write(f"{cut_id}|{save_str}\n")
+        generated_speech_output = [
+            generated_speech_output
+        ]  # WAR: only support batch = 1 for now
+        for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
+            speech_file_name = params.log_dir / f"{cut_id}.wav"
+            audio_tokens = [token for token in audio_tokens if token < 4096]
+            audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+            audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+            sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+            # with open(speech_token_file_name, 'w') as f:
+            #     # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            #     #torchaudio.save(save_path, speech_output.cpu(), 16000)
+            #     # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+            #     save_str = " ".join([str(i) for i in generated_speech_output])
+            #     f.write(f"{cut_id}|{save_str}\n")
     else:
         generated_ids = model.decode(
@@ -486,6 +546,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -548,14 +609,23 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         answers = batch["supervisions"]["text"]
-        questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-        answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-        texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+        questions_with_history = [
+            cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        ]
+        answer_cosyvoice_speech_token = [
+            cut.custom["answer_cosyvoice_speech_token"]
+            for cut in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            question.split("<USER>: ")[-1].strip()
+            for question in questions_with_history
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             batch=batch,
             tokenizer=tokenizer,
         )
@@ -643,9 +713,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
     params.log_dir.mkdir(parents=True, exist_ok=True)
-    setup_logger(
-        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
-    )
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")

     logging.info("Decoding started")
     logging.info(params)
@@ -657,6 +725,9 @@ def main():
     logging.info(f"device: {device}")

     model, tokenizer = get_model(params, device)
+    token2wav_model = CosyVoice(
+        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+    )

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -697,6 +768,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             tokenizer=tokenizer,
         )
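
Putting the new decode.py pieces together, the token-to-waveform path added above can be exercised roughly as follows. This is a sketch under the paths used in this diff (the CosyVoice repo and the CosyVoice-300M-SFT checkpoint under /workspace); the token ids are illustrative placeholders, and audio_decode_cosyvoice is the helper defined in this file.

# Sketch: converting generated codec tokens into a waveform by hand.
# Assumes /workspace/CosyVoice is on PYTHONPATH, as exported in the run script.
import sys

import soundfile as sf
import torch

sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
from cosyvoice.cli.cosyvoice import CosyVoice

from decode import audio_decode_cosyvoice  # helper defined above in decode.py

token2wav_model = CosyVoice(
    "/workspace/CosyVoice-300M-SFT", load_jit=False, load_trt=False, fp16=False
)

generated_speech_output = [1024, 57, 3001, 4095]  # illustrative codec ids for one cut
audio_tokens = [t for t in generated_speech_output if t < 4096]  # keep codec ids only
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)

audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write("sample.wav", audio_hat.squeeze(0).cpu().numpy(), 22050)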

View File

@@ -66,8 +66,9 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 # from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import (
@@ -146,6 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to enable speech codec output.",
     )

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -332,9 +334,7 @@ def compute_loss(
     # remove too long text
     # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
-        logging.warning(
-            f"Remove too long text, {messages} "
-        )
+        logging.warning(f"Remove too long text, {messages} ")
     max_len_texts = max([len(text) for text in texts])
     if tokenizer.padding_side == "right":
         texts = [
@@ -354,10 +354,10 @@ def compute_loss(
     # first get the indices of the tokens
     mask_prompt = True
     if mask_prompt:
-        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-        mask_indices = torch.where(
-            input_ids == default_speech_token_id
-        )
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(
+            DEFAULT_SPEECH_TOKEN
+        )
+        mask_indices = torch.where(input_ids == default_speech_token_id)
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
@@ -382,30 +382,39 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
     answers = batch["supervisions"]["text"]
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
     chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
-    answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
-    # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
+    answer_cosyvoice_speech_token = [
+        cut.custom["answer_cosyvoice_speech_token"]
+        for cut in batch["supervisions"]["cut"]
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
+    # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
     # <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
     messages = []
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": answers[i]}
+            {"role": "assistant", "content": answers[i]},
         ]
         messages.append(message)
@@ -423,7 +432,13 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
+        (
+            text_loss,
+            acc,
+            codec_loss,
+            codec_acc,
+            codec_topk_acc,
+        ) = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -445,12 +460,8 @@ def compute_loss(
             acc * info["frames"]
         )  # WAR: to avoid normalization by the number of frames
         if params.enable_speech_output:
-            info["codec_acc"] = (
-                codec_acc * info["frames"]
-            )
-            info["codec_topk_acc"] = (
-                codec_topk_acc * info["frames"]
-            )
+            info["codec_acc"] = codec_acc * info["frames"]
+            info["codec_topk_acc"] = codec_topk_acc * info["frames"]
             info["codec_loss"] = codec_loss.detach().cpu().item()
             info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -469,7 +480,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        with torch.amp.autocast('cuda', enabled=params.use_fp16):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -584,7 +595,7 @@ def train_one_epoch(
                     f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                 )
         try:
-            with torch.amp.autocast('cuda', enabled=params.use_fp16):
+            with torch.amp.autocast("cuda", enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -722,7 +733,6 @@ def run(rank, world_size, args):
     # model.resize_token_embeddings(len(tokenizer))
     # model.vocab_size = len(tokenizer)
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
@@ -736,12 +746,11 @@ def run(rank, world_size, args):
             param.requires_grad = False
         encoder_projector.eval()

     if params.enable_speech_output:
         # Determine attn_implementation and torch_dtype based on use_flash_attn
         if params.use_flash_attn:
             attn_implementation = "flash_attention_2"
             torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
         else:
             attn_implementation = "eager"
             torch_dtype = torch.float16
@@ -766,9 +775,9 @@ def run(rank, world_size, args):
         # Pass attn_implementation and torch_dtype to the constructor
         # Use AutoModelForCausalLM.from_config for more generality
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype
+            torch_dtype=torch_dtype,
         )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -803,7 +812,7 @@ def run(rank, world_size, args):
             llm,
             encoder_projector,
             codec_lm,
-            codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+            codec_lm_padding_side="left" if params.use_flash_attn else "right",
         )

     if params.pretrained_model_path:
@@ -851,12 +860,11 @@ def run(rank, world_size, args):
         codec_len = len(c.custom["answer_cosyvoice_speech_token"])
         if codec_len > 2200:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
             )
             return False
         return True

     train_cuts = data_module.train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
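
The multi-round history handling that appears in both compute_loss here and decode_one_batch in decode.py can be summarized by the small helper below. This is an illustrative sketch of the same string-splitting logic; DEFAULT_SPEECH_TOKEN is a stand-in for the constant defined in train.py, and the sample dialogue is the one quoted in the inline comments above.

# Sketch of how a cut's "question" field (multi-round history) becomes chat messages.
DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder for train.DEFAULT_SPEECH_TOKEN


def build_messages(question: str, total_round: int, answer: str):
    # Everything before the final "<USER>:" is prior-round history.
    history = question.rsplit("<USER>:", 1)[0].strip()
    message = []
    if total_round > 1:
        turns = [t for t in history.split("USER:") if t]
        for j in range(total_round - 1):
            question_answer = turns[j].split("ASSISTANT:")
            message += [
                {"role": "user", "content": question_answer[0].strip()},
                {"role": "assistant", "content": question_answer[1].strip()},
            ]
    # Current round: the user turn is the speech placeholder token,
    # and the assistant turn is the target text answer.
    message += [
        {"role": "user", "content": DEFAULT_SPEECH_TOKEN},
        {"role": "assistant", "content": answer},
    ]
    return message


example = (
    "USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 "
    "USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。"
    "<USER>: 告诉我如何烹饪鸡肉"
)
print(build_messages(example, total_round=3, answer="..."))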