add speech continuation pretraining

This commit is contained in:
root 2025-05-15 14:16:51 +00:00
parent e65725810c
commit f81363d324
4 changed files with 391 additions and 95 deletions

View File

@ -122,7 +122,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 1: Compute fbank feature from huggingface"
log "stage 6: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
@ -161,10 +161,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "stage 9: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
@ -195,7 +192,7 @@ fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "stage 11: Decoding EN, only support batch_size=1 for now."
log "stage 11: Decoding EN, val set only support batch_size=1 for now."
exp_dir=./qwen_omni/exp_speech2speech_en_continue
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
@ -256,3 +253,71 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
--output-dir test_result
done
fi
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
log "stage 15: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 600 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
log "stage 16: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=4
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 1200 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
$train_cmd_args
fi

View File

@ -24,7 +24,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import load_dataset
from datasets import interleave_datasets, load_dataset
from lhotse import (
CutSet,
WhisperFbank,
@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PerturbSpeed,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
@ -47,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
from utils import str2bool
from utils import get_rank, str2bool
class _SeedWorkers:
@ -123,6 +124,14 @@ class AsrDataModule:
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--on-the-fly-speed-perturb",
type=str2bool,
default=True,
help="When enabled, use on-the-fly speed perturbation. "
"Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
@ -188,27 +197,27 @@ class AsrDataModule:
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
default=None,
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
default="question_audio",
default="audio",
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default="answer",
default="text",
help="The key in the Huggingface dataset containing the text data",
)
group.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
# group.add_argument(
# "--resample-to-16kHz",
# type=str2bool,
# default=True,
# help="Resample audio to 16kHz. Default: False.",
# )
def train_dataloaders(
self,
@ -232,6 +241,8 @@ class AsrDataModule:
)
else:
logging.info("Disable MUSAN")
if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
input_transforms = []
if self.args.enable_spec_aug:
@ -260,9 +271,11 @@ class AsrDataModule:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
rank = get_rank()
train = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
@ -271,26 +284,6 @@ class AsrDataModule:
return_cuts=self.args.return_cuts,
)
# if self.args.on_the_fly_feats:
# # NOTE: the PerturbSpeed transform should be added only if we
# # remove it from data prep stage.
# # Add on-the-fly speed perturbation; since originally it would
# # have increased epoch size by 3, we will apply prob 2/3 and use
# # 3x more epochs.
# # Speed perturbation probably should come first before
# # concatenation, but in principle the transforms order doesn't have
# # to be strict (e.g. could be randomized)
# # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# # Drop feats to be on the safe side.
# train = K2SpeechRecognitionDataset(
# cut_transforms=transforms,
# input_strategy=OnTheFlyFeatures(
# WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
# ),
# input_transforms=input_transforms,
# return_cuts=self.args.return_cuts,
# )
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
@ -298,8 +291,7 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
buffer_size=self.args.num_buckets * 1000,
drop_last=self.args.drop_last,
)
else:
@ -339,10 +331,10 @@ class AsrDataModule:
CutSet for validation.
"""
logging.info("About to create dev dataset")
rank = get_rank()
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
@ -470,25 +462,231 @@ class AsrDataModule:
)
return {"test": VoiceAssistant_cuts}
# def train_cuts_en_vocalnet(self) -> CutSet:
@lru_cache()
def train_cuts_ultravox(self) -> CutSet:
logging.info("About to get train cuts")
if self.args.huggingface_dataset_path_or_name is not None:
librispeech_path = (
self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
)
people_speech_path = (
self.args.huggingface_dataset_path_or_name + "/peoples_speech"
)
gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
else:
librispeech_path = "fixie-ai/librispeech_asr"
people_speech_path = "fixie-ai/peoples_speech"
gigaspeech_path = "fixie-ai/gigaspeech"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
# 1_501_271
people_speech_clean = load_dataset(
people_speech_path, "clean", split="train", streaming=True
)
# 548_000
people_speech_dirty_sa = load_dataset(
people_speech_path, "dirty_sa", split="train", streaming=True
)
# 8_266_422
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
)
people_speech_clean_cuts = CutSet.from_huggingface_dataset(
people_speech_clean,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
people_speech_dirty_sa,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
gigaspeech_cuts,
people_speech_clean_cuts,
people_speech_dirty_sa_cuts,
weights=[
28539,
104014,
148688,
8266422,
1501271,
548000,
],
)
# @lru_cache()
# def train_cuts_ultravox(self) -> CutSet:
# logging.info("About to get train cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
# )
# return VoiceAssistant_cuts
# keep_columns = ["audio", "text", "continuation", "id"]
# librispeech_path="fixie-ai/librispeech_asr"
# # 148_688
# librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True)
# # 104_014
# librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True)
# # 28_539
# librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
# @lru_cache()
# def valid_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get valid cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
# )
# return VoiceAssistant_cuts
# cols_to_remove = librispeech_clean_100.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
# librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
# librispeech_other = librispeech_other.remove_columns(cols_to_remove)
# people_speech_path="fixie-ai/peoples_speech"
# # 1_501_271
# people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
# # 548_000
# people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
# cols_to_remove = people_speech_clean.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
# people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
# @lru_cache()
# def test_cuts_en_vocalnet(self) -> CutSet:
# logging.info("About to get test cuts")
# VoiceAssistant_cuts = load_manifest_lazy(
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
# # 8_266_422
# gigaspeech_path="fixie-ai/gigaspeech"
# gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
# # first rename segment_id to id
# gigaspeech = gigaspeech.rename_column("segment_id", "id")
# cols_to_remove = gigaspeech.column_names
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
# gigaspeech = gigaspeech.remove_columns(cols_to_remove)
# total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
# final_datasets = interleave_datasets([
# librispeech_clean_100,
# librispeech_clean_360,
# gigaspeech,
# people_speech_clean,
# people_speech_dirty_sa,
# librispeech_other,
# ], probabilities=[
# 28539 / total_item,
# 104014 / total_item,
# 8266422 / total_item,
# 1501271 / total_item,
# 548000 / total_item,
# 148688 / total_item,
# ])
# train_cuts = CutSet.from_huggingface_dataset(
# final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
# )
# return VoiceAssistant_cuts
# return train_cuts
@lru_cache()
def valid_cuts_ultravox(self) -> CutSet:
logging.info("About to get valid cuts")
librispeech_path = "fixie-ai/librispeech_asr"
librispeech_clean_valid = load_dataset(
librispeech_path, "clean", split="validation", streaming=True
)
librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_valid,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
return librispeech_clean_valid_cuts
@lru_cache()
def train_cuts_librispeech(self) -> CutSet:
logging.info("About to get train cuts")
# librispeech_path="fixie-ai/librispeech_asr"
librispeech_path = "/workspace/slam/librispeech_asr"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
weights=[
28539,
104014,
148688,
],
)
@lru_cache()
def train_cuts_gigaspeech(self) -> CutSet:
logging.info("About to get train cuts")
gigaspeech_path = "fixie-ai/gigaspeech"
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
)
return gigaspeech_cuts

View File

@ -10,3 +10,4 @@ transformers>=4.37.0
flash-attn
peft
torchmetrics
triton==3.3.0 # may be violate with openai-whisper

View File

@ -68,24 +68,26 @@ from transformers import (
Qwen2Config,
Qwen2ForCausalLM,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# from icefall import diagnostics
from utils import get_rank, get_world_size
# from icefall.env import get_env_info
# from icefall import diagnostics
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method('spawn')
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
@ -272,7 +274,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
"valid_interval": 3000,
# "env_info": get_env_info(),
}
)
@ -332,6 +334,21 @@ def process_batch_vocalnet(batch: dict):
return messages, answer_cosyvoice_speech_token
def process_batch_speech_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
# transcript = batch["supervisions"]["cut"][i].custom["text"]
messages.append(message)
return messages
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@ -429,13 +446,13 @@ def compute_loss(
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
if params.dataset_format == "slam_omni":
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
elif params.dataset_format == "speech_continuation":
messages = process_batch_speech_continuation(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@ -566,8 +583,11 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.encoder_projector.train()
# model.encoder_projector.train()
model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
@ -583,6 +603,9 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@ -594,7 +617,7 @@ def train_one_epoch(
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
tag=f"zero-checkpoint-{params.batch_idx_train}",
client_state={},
exclude_frozen_parameters=True,
)
@ -602,18 +625,18 @@ def train_one_epoch(
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
tag=f"zero-checkpoint-{params.batch_idx_train}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
@ -687,9 +710,9 @@ def run(rank, world_size, args):
fix_random_seed(params.seed)
setup_logger(f"{params.exp_dir}/log/log-train")
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
@ -698,7 +721,6 @@ def run(rank, world_size, args):
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
@ -721,7 +743,7 @@ def run(rank, world_size, args):
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
@ -809,6 +831,9 @@ def run(rank, world_size, args):
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
# set params.batch_idx_train according to the checkpoint name
if "checkpoint-" in params.pretrained_model_path:
params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -842,21 +867,22 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
codec_len = (
len(c.custom["answer_cosyvoice_speech_token"])
if "answer_cosyvoice_speech_token" in c.custom
else len(c.custom["speech_token"])
)
if codec_len > 2200:
if c.duration < 1.0 or c.duration > 29.5:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
codec_len = (
len(c.custom["answer_cosyvoice_speech_token"])
if "answer_cosyvoice_speech_token" in c.custom
else len(c.custom["speech_token"])
)
if codec_len > 2200:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
)
return False
return True
if params.dataset_format == "slam_omni":
@ -865,6 +891,11 @@ def run(rank, world_size, args):
elif params.dataset_format == "vocalnet":
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
elif params.dataset_format == "speech_continuation":
# train_cuts = data_module.train_cuts_ultravox()
# train_cuts = data_module.train_cuts_gigaspeech()
train_cuts = data_module.train_cuts_librispeech()
valid_cuts = data_module.valid_cuts_ultravox()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@ -879,7 +910,7 @@ def run(rank, world_size, args):
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
# train_dl = data_module.valid_dataloaders(train_cuts)
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
@ -913,25 +944,25 @@ def run(rank, world_size, args):
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
@ -971,6 +1002,7 @@ def main():
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)