Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

commit 448a4eeea7
parent d742043e75

    update hf dataset loading into lhotse
@@ -2,6 +2,7 @@
 # Copyright 2021 Johns Hopkins University (Piotr Żelasko)
 # Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
 # Copyright 2023 Xiaomi Corp. (Zengrui Jin)
+# Copyright 2025 Nvidia (Yuekai Zhang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -23,12 +24,7 @@ from pathlib import Path
 
 import torch
 from datasets import load_dataset
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    WhisperFbank,
-    WhisperFbankConfig,
-)
+from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
 
 from icefall.utils import str2bool
 
@@ -93,7 +89,12 @@ def get_parser():
         default="answer",
         help="The key in the Huggingface dataset containing the text data",
     )
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        default="belle",
+        help="""The dataset prefix to use when saving the features""",
+    )
     return parser
 
 
@@ -114,27 +115,28 @@ def compute_fbank(args):
             WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
         )
     else:
-        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+        raise NotImplementedError("Only WhisperFbank is implemented.")
 
     logging.info(f"device: {device}")
 
-    start = 0
-    stop = 1601
+    dataset = load_dataset(
+        args.huggingface_dataset_path_or_name, streaming=True, split="train"
+    )
+    num_shards = dataset.num_shards
     num_digits = 5
-    for i in range(start, stop):
+    for i in range(num_shards):
+        shard = dataset.shard(num_shards, i)
+        shard = shard.take(10)  # for testing
+        logging.info(
+            f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
+        )
+
         idx = f"{i}".zfill(num_digits)
-        # dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-        parquet_files = [
-            f"data/train-{idx}-of-01601.parquet",
-        ]
-        parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
-        file_name = parquet_files[0]
-        logging.info(f"Loading dataset from {file_name}")
-        dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
 
-        cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
+        cut_set = CutSet.from_huggingface_dataset(
+            shard, audio_key=args.audio_key, text_key=args.text_key
+        )
 
-        logging.info("Splitting cuts into smaller chunks")
         cut_set = cut_set.trim_to_supervisions(
             keep_overlapping=False, min_duration=None
         )
@@ -153,22 +155,13 @@ def compute_fbank(args):
             storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
-        cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
+        cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
         logging.info(f"Saving to {cuts_path}")
-        # cut_set.to_file(cuts_path)
-        remove_recording_item(cut_set, cuts_path)
-
-def remove_recording_item(
-    cuts,
-    output_cuts,
-):
-    """
-    don't store recording item
-    """
-    with CutSet.open_writer(output_cuts) as writer:
-        for cut in cuts:
-            cut.recording.sources = None
-            writer.write(cut)
+        # see https://github.com/lhotse-speech/lhotse/issues/1125
+        cut_set.drop_recordings().to_file(cuts_path)
+        if i > 1:
+            break
+
 
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
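
Note (illustrative sketch, not part of the commit): the two hunks above replace the hard-coded parquet file list with Hugging Face streaming shards. A condensed, self-contained version of the resulting flow is shown below. The dataset path, keys, and output locations mirror the stage-1 command later in this diff; the feature-storage call is assumed to be lhotse's compute_and_store_features_batch (the storage_type/overwrite context lines above suggest it, but the exact call is not shown in this excerpt), and the storage_path layout is an assumption.

# Condensed sketch of the new streaming-shard flow (assumed paths/keys;
# requires a `datasets` version where IterableDataset exposes num_shards/shard()).
from datasets import load_dataset
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig

dataset = load_dataset(
    "/workspace/Belle_1.4M-SLAM-Omni", streaming=True, split="train"
)
num_shards = dataset.num_shards
extractor = WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))

for i in range(num_shards):
    shard = dataset.shard(num_shards, i)
    cut_set = CutSet.from_huggingface_dataset(
        shard, audio_key="question_audio", text_key="answer"
    )
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, min_duration=None)
    cut_set = cut_set.resample(16000)  # mirrors --resample-to-16kHz True

    idx = f"{i}".zfill(5)
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"data/fbank_test/belle_feats_{idx}",  # assumed layout
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )
    # Drop in-memory recording sources before serializing the manifest,
    # see https://github.com/lhotse-speech/lhotse/issues/1125
    cut_set.drop_recordings().to_file(f"data/fbank_test/belle_cuts.{idx}.jsonl.gz")
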
@@ -20,10 +20,10 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
-  pip uninstall lhotse
-  cd /workspace/slam/lhotse
-  git config --global --add safe.directory /workspace/slam/lhotse
-  pip install -e '.[dev]'
+  #pip uninstall lhotse
+  #cd /workspace/slam/lhotse
+  #git config --global --add safe.directory /workspace/slam/lhotse
+  #pip install -e '.[dev]'
   cd -
   pip install -r slam_omni/requirements.txt
 fi
@@ -31,7 +31,12 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
 
-  python3 local/compute_whisper_fbank.py
+  python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank_test \
+    --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
+    --audio-key question_audio --text-key answer \
+    --prefix belle
 fi
 
 
@@ -52,16 +57,18 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   exp_dir=./slam_omni/exp_speech2speech_rerun
+  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 997 --avg 1 \
+    --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech_rerun \
+    --method e2e-epoch10_speech2speech_rerun \
     --enable-speech-output True \
+    --token2wav-path /workspace/CosyVoice-300M-SFT \
     --use-lora True # --on-the-fly-feats True
 
 fi
@@ -24,7 +24,14 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, load_manifest, load_manifest_lazy
+from datasets import load_dataset
+from lhotse import (
+    CutSet,
+    WhisperFbank,
+    WhisperFbankConfig,
+    load_manifest,
+    load_manifest_lazy,
+)
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -38,11 +45,11 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
+from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from datasets import load_dataset
 
 from icefall.utils import str2bool
-from speech_dataset import K2SpeechRecognitionDataset
 
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -310,7 +317,9 @@ class AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
+                input_strategy=OnTheFlyFeatures(
+                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -365,7 +374,9 @@ class AsrDataModule:
         logging.info("About to create dev dataset")
 
         validate = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -390,7 +401,9 @@ class AsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cpu"))
+            )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
@@ -419,16 +432,27 @@ class AsrDataModule:
             parquet_files = [
                 f"data/train-{idx}-of-01601.parquet",
             ]
-            parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
+            parquet_files = [
+                f"{self.args.huggingface_dataset_path_or_name}/{f}"
+                for f in parquet_files
+            ]
             file_name = parquet_files[0]
             logging.info(f"Loading dataset from {file_name}")
-            dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
-            cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
+            dataset = load_dataset(
+                "parquet", data_files=parquet_files, streaming=True, split="train"
+            )
+            cut_set = CutSet.from_huggingface_dataset(
+                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
+            )
             if self.args.resample_to_16kHz:
                 cut_set = cut_set.resample(16000)
-            return {'test':cut_set}
+            return {"test": cut_set}
         else:
-            return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
+            return {
+                "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+            }
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
@@ -436,8 +460,9 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             pass
         else:
-            return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+            )
 
     @lru_cache()
     def train_cuts(self) -> CutSet:
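
Note (illustrative sketch, not part of the commit): the data-module hunks above make test_cuts() return a dict mapping a split name to a CutSet, either streamed from parquet (when --on-the-fly-feats is set) or loaded from the manifest precomputed by local/compute_whisper_fbank.py. A quick way to inspect that precomputed manifest looks roughly like this; the path is the one hard-coded in the new test_cuts() branch.

# Minimal inspection sketch for the precomputed test manifest.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
for cut in cuts:
    # Features are stored with each cut; the raw recording sources were
    # dropped before saving (lhotse issue 1125), so only features and
    # supervision text remain.
    print(cut.id, round(cut.duration, 2), cut.supervisions[0].text)
    break
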
@@ -48,32 +48,74 @@ python3 ./whisper_llm_zh/decode.py \
 
 import argparse
 import logging
+import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import soundfile as sf
 import torch
 import torch.nn as nn
 import transformers
 import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 
 from peft import LoraConfig, get_peft_model
-from train import DEFAULT_SPEECH_TOKEN
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from train import add_model_arguments
 
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
     write_error_stats,
-    average_checkpoints,
 )
 
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def audio_decode_cosyvoice(audio_tokens, codec_decoder):
+    """
+    Generate audio from tokens with optional tone and prompt embedding.
+
+    Args:
+        audio_tokens (list): List of audio tokens to be processed.
+        codec_decoder: Codec decoder for generating audio.
+
+    Returns:
+        torch.Tensor: Generated audio waveform.
+    """
+    flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
+    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
+    prompt_speech_feat = torch.zeros(1, 0, 80)
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+        prompt_token_len=torch.tensor(
+            [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+        prompt_feat_len=torch.tensor(
+            [prompt_speech_feat.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        embedding=flow_embedding.to(codec_decoder.model.device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+
+    return audio_hat
+
+
 def get_model(params, device):
     """Load and prepare the speech-to-speech model."""
     if params.remove_whisper_encoder_input_length_restriction:
@@ -136,7 +178,7 @@ def get_model(params, device):
     # Determine attn_implementation and torch_dtype based on use_flash_attn
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
     else:
         attn_implementation = "eager"
         torch_dtype = torch.float16
@@ -162,7 +204,7 @@ def get_model(params, device):
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype
+            torch_dtype=torch_dtype,
         )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -197,7 +239,7 @@ def get_model(params, device):
         llm,
         encoder_projector,
         codec_lm,
-        codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )
 
     if params.avg > 1:
@@ -325,6 +367,12 @@ def get_parser():
         help="The experiment dir",
     )
 
+    parser.add_argument(
+        "--token2wav-path",
+        type=str,
+        default="/workspace/CosyVoice-300M-SFT",
+        help="The path to the token2wav model",
+    )
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
@@ -350,6 +398,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
    batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -431,26 +480,32 @@ def decode_one_batch(
     # {"role": "assistant", "content": ""},
     # ]
     # ] * len(feature)
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
     messages = []
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
            # {"role": "user", "content": f"{last_questions[i]}"},
-            {"role": "assistant", "content": ""}
+            {"role": "assistant", "content": ""},
         ]
         print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
@@ -461,16 +516,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id in cut_ids:
-            speech_token_file_name = (
-                params.log_dir / f"{cut_id}.txt"
-            )
-            with open(speech_token_file_name, 'w') as f:
-                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-                #torchaudio.save(save_path, speech_output.cpu(), 16000)
-                # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-                save_str = " ".join([str(i) for i in generated_speech_output])
-                f.write(f"{cut_id}|{save_str}\n")
+        generated_speech_output = [
+            generated_speech_output
+        ]  # WAR: only support batch = 1 for now
+        for cut_id, audio_tokens in zip(cut_ids, generated_speech_output):
+            speech_file_name = params.log_dir / f"{cut_id}.wav"
+            audio_tokens = [token for token in audio_tokens if token < 4096]
+            audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
+            audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
+            sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
+            # with open(speech_token_file_name, 'w') as f:
+            #     # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            #     #torchaudio.save(save_path, speech_output.cpu(), 16000)
+            #     # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+            #     save_str = " ".join([str(i) for i in generated_speech_output])
+            #     f.write(f"{cut_id}|{save_str}\n")
 
     else:
         generated_ids = model.decode(
@@ -486,6 +546,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: AutoTokenizer,
+    token2wav_model: nn.Module,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -548,14 +609,23 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         answers = batch["supervisions"]["text"]
-        questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
-        answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-        texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
+        questions_with_history = [
+            cut.custom["question"] for cut in batch["supervisions"]["cut"]
+        ]
+        answer_cosyvoice_speech_token = [
+            cut.custom["answer_cosyvoice_speech_token"]
+            for cut in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            question.split("<USER>: ")[-1].strip()
+            for question in questions_with_history
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             batch=batch,
             tokenizer=tokenizer,
         )
@@ -643,9 +713,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
     params.log_dir.mkdir(parents=True, exist_ok=True)
-    setup_logger(
-        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
-    )
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
 
     logging.info("Decoding started")
     logging.info(params)
@@ -657,6 +725,9 @@ def main():
     logging.info(f"device: {device}")
 
     model, tokenizer = get_model(params, device)
+    token2wav_model = CosyVoice(
+        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
+    )
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -697,6 +768,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            token2wav_model=token2wav_model,
             tokenizer=tokenizer,
         )
 
@@ -66,8 +66,9 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
+
 # from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import (
@@ -146,6 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to enable speech codec output.",
     )
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -332,9 +334,7 @@ def compute_loss(
     # remove too long text
     # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
-        logging.warning(
-            f"Remove too long text, {messages} "
-        )
+        logging.warning(f"Remove too long text, {messages} ")
     max_len_texts = max([len(text) for text in texts])
     if tokenizer.padding_side == "right":
         texts = [
@@ -354,10 +354,10 @@ def compute_loss(
     # first get the indices of the tokens
     mask_prompt = True
     if mask_prompt:
-        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-        mask_indices = torch.where(
-            input_ids == default_speech_token_id
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(
+            DEFAULT_SPEECH_TOKEN
         )
+        mask_indices = torch.where(input_ids == default_speech_token_id)
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
@@ -382,11 +382,20 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
 
     answers = batch["supervisions"]["text"]
-    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    questions_with_history = [
+        cut.custom["question"] for cut in batch["supervisions"]["cut"]
+    ]
     chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
-    answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
-    last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
-    history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
+    answer_cosyvoice_speech_token = [
+        cut.custom["answer_cosyvoice_speech_token"]
+        for cut in batch["supervisions"]["cut"]
+    ]
+    last_questions = [
+        question.split("<USER>: ")[-1].strip() for question in questions_with_history
+    ]
+    history_contexts = [
+        question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
+    ]
     # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
     # <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
 
@@ -394,18 +403,18 @@ def compute_loss(
     for i, total_round in enumerate(chat_rounds):
         message = []
         if total_round > 1:
-            history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
-                question_answer = history_question_answer[j].split('ASSISTANT:')
+                question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
-                    {"role": "assistant", "content": question_answer[1].strip()}
+                    {"role": "assistant", "content": question_answer[1].strip()},
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-            {"role": "assistant", "content": answers[i]}
+            {"role": "assistant", "content": answers[i]},
         ]
         messages.append(message)
 
@@ -423,7 +432,13 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
+        (
+            text_loss,
+            acc,
+            codec_loss,
+            codec_acc,
+            codec_topk_acc,
+        ) = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -445,12 +460,8 @@ def compute_loss(
             acc * info["frames"]
         )  # WAR: to avoid normalization by the number of frames
         if params.enable_speech_output:
-            info["codec_acc"] = (
-                codec_acc * info["frames"]
-            )
-            info["codec_topk_acc"] = (
-                codec_topk_acc * info["frames"]
-            )
+            info["codec_acc"] = codec_acc * info["frames"]
+            info["codec_topk_acc"] = codec_topk_acc * info["frames"]
             info["codec_loss"] = codec_loss.detach().cpu().item()
             info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -469,7 +480,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.amp.autocast('cuda', enabled=params.use_fp16):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -584,7 +595,7 @@ def train_one_epoch(
                     f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                 )
         try:
-            with torch.amp.autocast('cuda', enabled=params.use_fp16):
+            with torch.amp.autocast("cuda", enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -722,7 +733,6 @@ def run(rank, world_size, args):
     # model.resize_token_embeddings(len(tokenizer))
     # model.vocab_size = len(tokenizer)
 
-
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
@@ -736,12 +746,11 @@ def run(rank, world_size, args):
            param.requires_grad = False
        encoder_projector.eval()
 
-
    if params.enable_speech_output:
        # Determine attn_implementation and torch_dtype based on use_flash_attn
        if params.use_flash_attn:
            attn_implementation = "flash_attention_2"
            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
        else:
            attn_implementation = "eager"
            torch_dtype = torch.float16
@@ -768,7 +777,7 @@ def run(rank, world_size, args):
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype
+            torch_dtype=torch_dtype,
         )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
@@ -803,7 +812,7 @@ def run(rank, world_size, args):
         llm,
         encoder_projector,
         codec_lm,
-        codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )
 
     if params.pretrained_model_path:
@@ -851,12 +860,11 @@ def run(rank, world_size, args):
         codec_len = len(c.custom["answer_cosyvoice_speech_token"])
         if codec_len > 2200:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
             )
             return False
         return True
 
 
     train_cuts = data_module.train_cuts()
-
     train_cuts = train_cuts.filter(remove_short_and_long_utt)