Mirror of https://github.com/k2-fsa/icefall.git

Commit 448a4eeea7 (parent d742043e75): update hf dataset loading into lhotse
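The commit replaces hand-rolled parquet bookkeeping in the feature-extraction script with lhotse's native Hugging Face integration: the dataset is opened in streaming mode, iterated shard by shard, and each shard is converted to a `CutSet` directly. A condensed sketch of the pattern adopted below (assuming a `datasets` release that provides `IterableDataset.num_shards` and `.shard()`, and a lhotse version with `CutSet.from_huggingface_dataset`; the dataset path and column names are taken from the run script in this diff):

```python
from datasets import load_dataset
from lhotse import CutSet

# Streaming mode: rows are read lazily instead of downloading everything up front.
dataset = load_dataset(
    "/workspace/Belle_1.4M-SLAM-Omni", streaming=True, split="train"
)
num_shards = dataset.num_shards  # one shard per underlying data file
for i in range(num_shards):
    shard = dataset.shard(num_shards, i)  # still lazy; no data is copied
    # lhotse wraps each HF row as a cut, reading audio and text from the
    # configured columns of this dataset.
    cut_set = CutSet.from_huggingface_dataset(
        shard, audio_key="question_audio", text_key="answer"
    )
    # ...compute fbank features shard by shard, as in the diff below.
```

Sharding the streaming dataset keeps memory bounded and makes each shard's output independent, which is why the loop below writes one `{prefix}_cuts.{idx}.jsonl.gz` manifest per shard.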
@@ -2,6 +2,7 @@
 # Copyright    2021  Johns Hopkins University (Piotr Żelasko)
 # Copyright    2021  Xiaomi Corp.             (Fangjun Kuang)
 # Copyright    2023  Xiaomi Corp.             (Zengrui Jin)
+# Copyright    2025  Nvidia                   (Yuekai Zhang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -23,12 +24,7 @@ from pathlib import Path
 
 import torch
+from datasets import load_dataset
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    WhisperFbank,
-    WhisperFbankConfig,
-)
+from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
 
 from icefall.utils import str2bool
 
@@ -93,7 +89,12 @@ def get_parser():
         default="answer",
         help="The key in the Huggingface dataset containing the text data",
     )
 
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        default="belle",
+        help="""The dataset prefix to use when saving the features""",
+    )
     return parser
 
 
@@ -114,27 +115,28 @@ def compute_fbank(args):
             WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
         )
     else:
-        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+        raise NotImplementedError("Only WhisperFbank is implemented.")
 
     logging.info(f"device: {device}")
 
-    start = 0
-    stop = 1601
+    dataset = load_dataset(
+        args.huggingface_dataset_path_or_name, streaming=True, split="train"
+    )
+    num_shards = dataset.num_shards
     num_digits = 5
-    for i in range(start, stop):
+    for i in range(num_shards):
+        shard = dataset.shard(num_shards, i)
+        shard = shard.take(10)  # for testing
+        logging.info(
+            f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
+        )
 
         idx = f"{i}".zfill(num_digits)
-        # dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-        parquet_files = [
-            f"data/train-{idx}-of-01601.parquet",
-        ]
-        parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
-        file_name = parquet_files[0]
-        logging.info(f"Loading dataset from {file_name}")
-        dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-
-        cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
+        cut_set = CutSet.from_huggingface_dataset(
+            shard, audio_key=args.audio_key, text_key=args.text_key
+        )
 
         logging.info("Splitting cuts into smaller chunks")
         cut_set = cut_set.trim_to_supervisions(
             keep_overlapping=False, min_duration=None
         )
@@ -153,22 +155,13 @@ def compute_fbank(args):
             storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
-        cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
+        cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
         logging.info(f"Saving to {cuts_path}")
-        # cut_set.to_file(cuts_path)
-        remove_recording_item(cut_set, cuts_path)
+        # see https://github.com/lhotse-speech/lhotse/issues/1125
+        cut_set.drop_recordings().to_file(cuts_path)
+        if i > 1:
+            break
 
-def remove_recording_item(
-    cuts,
-    output_cuts,
-):
-    """
-    don't store recording item
-    """
-    with CutSet.open_writer(output_cuts) as writer:
-        for cut in cuts:
-            cut.recording.sources = None
-            writer.write(cut)
 
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
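Two details worth noting in the hunk above: `shard.take(10)` and the `if i > 1: break` guard cap the run for testing, as the inline comment says; and `cut_set.drop_recordings()` supersedes the hand-written `remove_recording_item` helper, following the linked lhotse issue #1125, so the saved manifests no longer carry recording sources. The next hunks update the driver shell script.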
@@ -20,10 +20,10 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
-  pip uninstall lhotse
-  cd /workspace/slam/lhotse
-  git config --global --add safe.directory /workspace/slam/lhotse
-  pip install -e '.[dev]'
+  #pip uninstall lhotse
+  #cd /workspace/slam/lhotse
+  #git config --global --add safe.directory /workspace/slam/lhotse
+  #pip install -e '.[dev]'
   cd -
   pip install -r slam_omni/requirements.txt
 fi
@@ -31,7 +31,12 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
 
-  python3 local/compute_whisper_fbank.py
+  python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank_test \
+    --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
+    --audio-key question_audio --text-key answer \
+    --prefix belle
 fi
 
 
@@ -52,16 +57,18 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   exp_dir=./slam_omni/exp_speech2speech_rerun
+  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 997 --avg 1 \
+    --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech_rerun \
+    --method e2e-epoch10_speech2speech_rerun \
     --enable-speech-output True \
+    --token2wav-path /workspace/CosyVoice-300M-SFT \
     --use-lora True # --on-the-fly-feats True
 
 fi
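The following hunks appear to belong to the data module defining `AsrDataModule` (imported as `data_module` by the decoding script): imports are regrouped with `datasets` moved into the top-level block, the `OnTheFlyFeatures(WhisperFbank(...))` input strategies are reflowed, and `test_cuts` now reads the stage-1 output manifest `data/fbank_test/belle_cuts.00000.jsonl.gz`.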
@@ -24,7 +24,14 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest, load_manifest_lazy
+from datasets import load_dataset
+from lhotse import (
+    CutSet,
+    WhisperFbank,
+    WhisperFbankConfig,
+    load_manifest,
+    load_manifest_lazy,
+)
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -38,11 +45,11 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
-from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from datasets import load_dataset
 
 from icefall.utils import str2bool
+from speech_dataset import K2SpeechRecognitionDataset
 
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -310,7 +317,9 @@ class AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
+                input_strategy=OnTheFlyFeatures(
+                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -365,7 +374,9 @@ class AsrDataModule:
         logging.info("About to create dev dataset")
 
         validate = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -390,7 +401,9 @@ class AsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -419,16 +432,27 @@ class AsrDataModule:
             parquet_files = [
                 f"data/train-{idx}-of-01601.parquet",
             ]
-            parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
+            parquet_files = [
+                f"{self.args.huggingface_dataset_path_or_name}/{f}"
+                for f in parquet_files
+            ]
             file_name = parquet_files[0]
             logging.info(f"Loading dataset from {file_name}")
-            dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-            cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
+            dataset = load_dataset(
+                "parquet", data_files=parquet_files, streaming=True, split="train"
+            )
+            cut_set = CutSet.from_huggingface_dataset(
+                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
+            )
             if self.args.resample_to_16kHz:
                 cut_set = cut_set.resample(16000)
-            return {'test':cut_set}
+            return {"test": cut_set}
         else:
-            return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
+            return {
+                "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+            }
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
@@ -436,8 +460,9 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             pass
         else:
-            return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
-
+            return load_manifest_lazy(
+                self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+            )
 
     @lru_cache()
     def train_cuts(self) -> CutSet:
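Next comes the decoding script (`./slam_omni/decode.py`, invoked in stage 3 above). It gains a CosyVoice token2wav stage so that generated codec tokens are vocoded into waveforms rather than dumped as text, and it reflows the parsing of flattened multi-round dialogue strings. As a standalone illustration of what that parsing does (a hypothetical helper, not code from the commit; the real data marks the current turn with `<USER>:`, which the script first strips off via `rsplit`):

```python
def parse_history(history: str) -> list:
    """Split "USER: q1 ASSISTANT: a1 USER: q2 ASSISTANT: a2" into chat messages."""
    messages = []
    for turn in filter(None, history.split("USER:")):
        # Each turn looks like "<question> ASSISTANT: <answer>".
        question, _, answer = turn.partition("ASSISTANT:")
        messages += [
            {"role": "user", "content": question.strip()},
            {"role": "assistant", "content": answer.strip()},
        ]
    return messages
```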
@@ -48,32 +48,74 @@ python3 ./whisper_llm_zh/decode.py \
 
 import argparse
 import logging
+import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import soundfile as sf
 import torch
 import torch.nn as nn
 import transformers
 import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-
 from peft import LoraConfig, get_peft_model
-from train import DEFAULT_SPEECH_TOKEN
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from train import add_model_arguments
 
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
     write_error_stats,
+    average_checkpoints,
 )
 
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+    """
+    Generate audio from tokens with optional tone and prompt embedding.
+
+    Args:
+        audio_tokens (list): List of audio tokens to be processed.
+        codec_decoder: Codec decoder for generating audio.
+
+    Returns:
+        torch.Tensor: Generated audio waveform.
+    """
+    flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+    prompt_speech_feat = torch.zeros(1, 0, 80)
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+        prompt_token_len=torch.tensor(
+            [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+        prompt_feat_len=torch.tensor(
+            [prompt_speech_feat.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        embedding=flow_embedding.to(codec_decoder.model.device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+
+    return audio_hat
+
+
 def get_model(params, device):
     """Load and prepare the speech-to-speech model."""
     if params.remove_whisper_encoder_input_length_restriction:
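The added `audio_decode_cosyvoice` helper drives CosyVoice's two internal stages by hand: the flow model converts codec tokens into a mel spectrogram, and the HiFT vocoder converts that spectrogram into a waveform. The zero-length prompt tensors mean no reference audio is supplied; the voice comes entirely from the bundled "中文女" (Chinese female) speaker embedding.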
@@ -162,7 +204,7 @@ def get_model(params, device):
     codec_lm = AutoModelForCausalLM.from_config(
         config=config,
         attn_implementation=attn_implementation,
-        torch_dtype=torch_dtype
+        torch_dtype=torch_dtype,
     )
     # cosyvoice2_token_size = 6561
     codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -325,6 +367,12 @@ def get_parser():
         help="The experiment dir",
     )
 
+    parser.add_argument(
+        "--token2wav-path",
+        type=str,
+        default="/workspace/CosyVoice-300M-SFT",
+        help="The path to the token2wav model",
+    )
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
@@ -350,6 +398,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
     batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -431,26 +480,32 @@ def decode_one_batch(
     #         {"role": "assistant", "content": ""},
     #     ]
     # ] * len(feature)
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
     messages = []
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # e.g. "USER: Write a poem about summer. ASSISTANT: The summer sun blazes... USER: List some news headlines. ASSISTANT: Today's news never stops."
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
             # {"role": "user", "content": f"{last_questions[i]}"},
-            {"role": "assistant", "content": ""}
+            {"role": "assistant", "content": ""},
         ]
         print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
@@ -461,16 +516,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id in cut_ids:
-            speech_token_file_name = (
-                params.log_dir / f"{cut_id}.txt"
-            )
-            with open(speech_token_file_name, 'w') as f:
-                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-                #torchaudio.save(save_path, speech_output.cpu(), 16000)
-                # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-                save_str = " ".join([str(i) for i in generated_speech_output])
-                f.write(f"{cut_id}|{save_str}\n")
+        generated_speech_output = [
+            generated_speech_output
+        ]  # WAR: only support batch = 1 for now
+        for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
+            speech_file_name = params.log_dir / f"{cut_id}.wav"
+            audio_tokens = [token for token in audio_tokens if token < 4096]
+            audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+            audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+            sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+            # with open(speech_token_file_name, 'w') as f:
+            #     # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            #     #torchaudio.save(save_path, speech_output.cpu(), 16000)
+            #     # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+            #     save_str = " ".join([str(i) for i in generated_speech_output])
+            #     f.write(f"{cut_id}|{save_str}\n")
 
     else:
         generated_ids = model.decode(
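In the rewritten output loop above, tokens >= 4096 are filtered out before vocoding (presumably ids outside the CosyVoice codec vocabulary, e.g. padding or end-of-speech markers), and the waveform is written at 22050 Hz. As the WAR comment notes, wrapping the output in a single-element list assumes a batch size of 1, matching `--max-duration 1` in the run script.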
@@ -486,6 +546,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -548,14 +609,23 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         answers = batch["supervisions"]["text"]
-        questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-        answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-        texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+        questions_with_history = [
+            cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        ]
+        answer_cosyvoice_speech_token = [
+            cut.custom["answer_cosyvoice_speech_token"]
+            for cut in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            question.split("<USER>: ")[-1].strip()
+            for question in questions_with_history
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             batch=batch,
             tokenizer=tokenizer,
         )
@@ -643,9 +713,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
     params.log_dir.mkdir(parents=True, exist_ok=True)
-    setup_logger(
-        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
-    )
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
 
     logging.info("Decoding started")
     logging.info(params)
@@ -657,6 +725,9 @@ def main():
     logging.info(f"device: {device}")
 
     model, tokenizer = get_model(params, device)
+    token2wav_model = CosyVoice(
+        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+    )
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -697,6 +768,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             tokenizer=tokenizer,
         )
 
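The final hunks are in the training script: the same dialogue-parsing and message-building cleanup as in the decoding script, the unused `prepare_model_for_kbit_training` import dropped, `torch.amp.autocast('cuda', ...)` normalized to double quotes, and the multi-value return of `forward_with_speech_output` unpacked across lines. These are formatting and cleanup changes around the new speech-output path rather than behavioral changes.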
@@ -66,8 +66,9 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
+
 # from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import (
@@ -146,6 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to enable speech codec output.",
     )
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -332,9 +334,7 @@ def compute_loss(
     # remove too long text
     # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
-        logging.warning(
-            f"Remove too long text, {messages} "
-        )
+        logging.warning(f"Remove too long text, {messages} ")
     max_len_texts = max([len(text) for text in texts])
     if tokenizer.padding_side == "right":
         texts = [
@@ -354,10 +354,10 @@ def compute_loss(
         # first get the indices of the tokens
         mask_prompt = True
         if mask_prompt:
-            default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-            mask_indices = torch.where(
-                input_ids == default_speech_token_id
+            default_speech_token_id = tokenizer.convert_tokens_to_ids(
+                DEFAULT_SPEECH_TOKEN
             )
+            mask_indices = torch.where(input_ids == default_speech_token_id)
             for i in range(mask_indices[0].size(0)):
                 row = mask_indices[0][i]
                 col = mask_indices[1][i]
@@ -382,11 +382,20 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
 
     answers = batch["supervisions"]["text"]
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
     chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
-    answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
+    answer_cosyvoice_speech_token = [
+        cut.custom["answer_cosyvoice_speech_token"]
+        for cut in batch["supervisions"]["cut"]
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
     # multi-round e.g. "USER: Write a poem about summer. ASSISTANT: The summer sun blazes... USER: List some news headlines. ASSISTANT: Today's news never stops. <USER>: Tell me how to cook chicken"
     # single-round e.g. "<USER>: Appraise this sentence: He is kind-hearted. The output is 'He is a person with a kind heart.'"
 
@@ -394,18 +403,18 @@ def compute_loss(
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # e.g. "USER: Write a poem about summer. ASSISTANT: The summer sun blazes... USER: List some news headlines. ASSISTANT: Today's news never stops."
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": answers[i]}
+            {"role": "assistant", "content": answers[i]},
        ]
         messages.append(message)
 
@@ -423,7 +432,13 @@ def compute_loss(
                 labels=target_ids.to(device),
             )
         else:
-            text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
+            (
+                text_loss,
+                acc,
+                codec_loss,
+                codec_acc,
+                codec_topk_acc,
+            ) = model.forward_with_speech_output(
                 fbank=feature,
                 input_ids=input_ids.to(device),
                 attention_mask=attention_mask.to(device),
@@ -445,12 +460,8 @@ def compute_loss(
             acc * info["frames"]
         )  # WAR: to avoid normalization by the number of frames
         if params.enable_speech_output:
-            info["codec_acc"] = (
-                codec_acc * info["frames"]
-            )
-            info["codec_topk_acc"] = (
-                codec_topk_acc * info["frames"]
-            )
+            info["codec_acc"] = codec_acc * info["frames"]
+            info["codec_topk_acc"] = codec_topk_acc * info["frames"]
             info["codec_loss"] = codec_loss.detach().cpu().item()
             info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -469,7 +480,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.amp.autocast('cuda', enabled=params.use_fp16):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -584,7 +595,7 @@ def train_one_epoch(
                 f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
             )
         try:
-            with torch.amp.autocast('cuda', enabled=params.use_fp16):
+            with torch.amp.autocast("cuda", enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -722,7 +733,6 @@ def run(rank, world_size, args):
     # model.resize_token_embeddings(len(tokenizer))
     # model.vocab_size = len(tokenizer)
 
-
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
@@ -736,7 +746,6 @@ def run(rank, world_size, args):
         param.requires_grad = False
     encoder_projector.eval()
 
-
     if params.enable_speech_output:
         # Determine attn_implementation and torch_dtype based on use_flash_attn
        if params.use_flash_attn:
@@ -768,7 +777,7 @@ def run(rank, world_size, args):
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype
+            torch_dtype=torch_dtype,
         )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -856,7 +865,6 @@ def run(rank, world_size, args):
             return False
         return True
 
-
     train_cuts = data_module.train_cuts()
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)