add speech continuation pretraining

root 2025-05-15 14:16:51 +00:00
parent e65725810c
commit f81363d324
4 changed files with 391 additions and 95 deletions

View File

@@ -122,7 +122,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-log "stage 1: Compute fbank feature from huggingface"
+log "stage 6: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
@@ -161,10 +161,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
-fi
-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-log "stage 9: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
@@ -195,7 +192,7 @@ fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-log "stage 11: Decoding EN, only support batch_size=1 for now."
+log "stage 11: Decoding EN, val set only support batch_size=1 for now."
exp_dir=./qwen_omni/exp_speech2speech_en_continue
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
@@ -256,3 +253,71 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
--output-dir test_result
done
fi
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
log "stage 15: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=2
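# Adapter-only pretraining on speech continuation: the Whisper encoder and the Qwen2.5 LLM stay
# frozen (--unfreeze-llm False), only the speech projector is trained, and no speech output head
# is used (--enable-speech-output False).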
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 600 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
log "stage 16: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=4
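# Auto-resume: scan $exp_dir for checkpoint-<step> directories and continue from the largest
# step found, restoring both the model weights and the sampler state.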
latest_checkpoint_step=-1
# Check if exp_dir exists and is a directory
if [ -d "$exp_dir" ]; then
# List directories matching checkpoint-* and find the one with the largest step number
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
# Extract step number using parameter expansion
current_step=${checkpoint_name#checkpoint-}
# Ensure current_step is a number
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
latest_checkpoint_step=$current_step
fi
done
fi
train_cmd_args="--max-duration 1200 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
if [ "$latest_checkpoint_step" -ge 0 ]; then
log "Continuing training from checkpoint-$latest_checkpoint_step"
step=$latest_checkpoint_step
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
else
log "Starting training from scratch as no checkpoint was found in $exp_dir"
# No pretrained model or sampler state dict needed for the first run
fi
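# Launch one process per GPU with torchrun; DeepSpeed applies the ZeRO-1 optimizer-state
# partitioning configured in ./qwen_omni/ds_config_zero1.json.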
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
$train_cmd_args
fi

View File

@@ -24,7 +24,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
-from datasets import load_dataset
+from datasets import interleave_datasets, load_dataset
from lhotse import (
CutSet,
WhisperFbank,
@@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
+PerturbSpeed,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
@@ -47,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader
-from utils import str2bool
+from utils import get_rank, str2bool
class _SeedWorkers:
@@ -123,6 +124,14 @@ class AsrDataModule:
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--on-the-fly-speed-perturb",
type=str2bool,
default=True,
help="When enabled, use on-the-fly speed perturbation. "
"Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
@@ -188,27 +197,27 @@ class AsrDataModule:
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
-default="/workspace/Belle_1.4M-SLAM-Omni",
+default=None,
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
-default="question_audio",
+default="audio",
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
-default="answer",
+default="text",
help="The key in the Huggingface dataset containing the text data",
)
-group.add_argument(
-"--resample-to-16kHz",
-type=str2bool,
-default=True,
-help="Resample audio to 16kHz. Default: False.",
-)
+# group.add_argument(
+#     "--resample-to-16kHz",
+#     type=str2bool,
+#     default=True,
+#     help="Resample audio to 16kHz. Default: False.",
+# )
def train_dataloaders(
self,
@@ -232,6 +241,8 @@ class AsrDataModule:
)
else:
logging.info("Disable MUSAN")
if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
input_transforms = []
if self.args.enable_spec_aug:
@@ -260,9 +271,11 @@ class AsrDataModule:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
rank = get_rank()
train = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
-WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
@@ -271,26 +284,6 @@
return_cuts=self.args.return_cuts,
)
# if self.args.on_the_fly_feats:
# # NOTE: the PerturbSpeed transform should be added only if we
# # remove it from data prep stage.
# # Add on-the-fly speed perturbation; since originally it would
# # have increased epoch size by 3, we will apply prob 2/3 and use
# # 3x more epochs.
# # Speed perturbation probably should come first before
# # concatenation, but in principle the transforms order doesn't have
# # to be strict (e.g. could be randomized)
# # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# # Drop feats to be on the safe side.
# train = K2SpeechRecognitionDataset(
# cut_transforms=transforms,
# input_strategy=OnTheFlyFeatures(
# WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
# ),
# input_transforms=input_transforms,
# return_cuts=self.args.return_cuts,
# )
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
@@ -298,8 +291,7 @@
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
-buffer_size=self.args.num_buckets * 2000,
-shuffle_buffer_size=self.args.num_buckets * 5000,
+buffer_size=self.args.num_buckets * 1000,
drop_last=self.args.drop_last,
)
else:
@@ -339,10 +331,10 @@
CutSet for validation.
"""
logging.info("About to create dev dataset")
rank = get_rank()
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
-WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
@@ -470,25 +462,231 @@
)
return {"test": VoiceAssistant_cuts}
-# def train_cuts_en_vocalnet(self) -> CutSet:
+@lru_cache()
def train_cuts_ultravox(self) -> CutSet:
logging.info("About to get train cuts")
if self.args.huggingface_dataset_path_or_name is not None:
librispeech_path = (
self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
)
people_speech_path = (
self.args.huggingface_dataset_path_or_name + "/peoples_speech"
)
gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
else:
librispeech_path = "fixie-ai/librispeech_asr"
people_speech_path = "fixie-ai/peoples_speech"
gigaspeech_path = "fixie-ai/gigaspeech"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
# 1_501_271
people_speech_clean = load_dataset(
people_speech_path, "clean", split="train", streaming=True
)
# 548_000
people_speech_dirty_sa = load_dataset(
people_speech_path, "dirty_sa", split="train", streaming=True
)
# 8_266_422
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
)
people_speech_clean_cuts = CutSet.from_huggingface_dataset(
people_speech_clean,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
people_speech_dirty_sa,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
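# Mux the streaming corpora with weights equal to their (approximate) utterance counts noted
# above, so each corpus is sampled in proportion to its size.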
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
gigaspeech_cuts,
people_speech_clean_cuts,
people_speech_dirty_sa_cuts,
weights=[
28539,
104014,
148688,
8266422,
1501271,
548000,
],
)
-#     logging.info("About to get train cuts")
-#     VoiceAssistant_cuts = load_manifest_lazy(
-#         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-#     )
-#     return VoiceAssistant_cuts
-# @lru_cache()
-# def valid_cuts_en_vocalnet(self) -> CutSet:
-#     logging.info("About to get valid cuts")
-#     VoiceAssistant_cuts = load_manifest_lazy(
-#         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-#     )
-#     return VoiceAssistant_cuts
-# @lru_cache()
-# def test_cuts_en_vocalnet(self) -> CutSet:
-#     logging.info("About to get test cuts")
-#     VoiceAssistant_cuts = load_manifest_lazy(
-#         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-#     )
-#     return VoiceAssistant_cuts
+# @lru_cache()
+# def train_cuts_ultravox(self) -> CutSet:
+#     logging.info("About to get train cuts")
+#     keep_columns = ["audio", "text", "continuation", "id"]
+#     librispeech_path="fixie-ai/librispeech_asr"
+#     # 148_688
+#     librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True)
+#     # 104_014
+#     librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True)
+#     # 28_539
+#     librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
+#     cols_to_remove = librispeech_clean_100.column_names
+#     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+#     librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
+#     librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
+#     librispeech_other = librispeech_other.remove_columns(cols_to_remove)
+#     people_speech_path="fixie-ai/peoples_speech"
+#     # 1_501_271
+#     people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
+#     # 548_000
+#     people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
+#     cols_to_remove = people_speech_clean.column_names
+#     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+#     people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
+#     people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
+#     # 8_266_422
+#     gigaspeech_path="fixie-ai/gigaspeech"
+#     gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
+#     # first rename segment_id to id
+#     gigaspeech = gigaspeech.rename_column("segment_id", "id")
+#     cols_to_remove = gigaspeech.column_names
+#     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+#     gigaspeech = gigaspeech.remove_columns(cols_to_remove)
+#     total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
+#     final_datasets = interleave_datasets([
+#         librispeech_clean_100,
+#         librispeech_clean_360,
+#         gigaspeech,
+#         people_speech_clean,
+#         people_speech_dirty_sa,
+#         librispeech_other,
+#     ], probabilities=[
+#         28539 / total_item,
+#         104014 / total_item,
+#         8266422 / total_item,
+#         1501271 / total_item,
+#         548000 / total_item,
+#         148688 / total_item,
+#     ])
+#     train_cuts = CutSet.from_huggingface_dataset(
+#         final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
+#     )
+#     return train_cuts
@lru_cache()
def valid_cuts_ultravox(self) -> CutSet:
logging.info("About to get valid cuts")
librispeech_path = "fixie-ai/librispeech_asr"
librispeech_clean_valid = load_dataset(
librispeech_path, "clean", split="validation", streaming=True
)
librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_valid,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
return librispeech_clean_valid_cuts
@lru_cache()
def train_cuts_librispeech(self) -> CutSet:
logging.info("About to get train cuts")
# librispeech_path="fixie-ai/librispeech_asr"
librispeech_path = "/workspace/slam/librispeech_asr"
# 148_688
librispeech_other = load_dataset(
librispeech_path, "other", split="train.500", streaming=True
)
# 104_014
librispeech_clean_360 = load_dataset(
librispeech_path, "clean", split="train.360", streaming=True
)
# 28_539
librispeech_clean_100 = load_dataset(
librispeech_path, "clean", split="train.100", streaming=True
)
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_100,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_other_cuts = CutSet.from_huggingface_dataset(
librispeech_other,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
librispeech_clean_360,
audio_key=self.args.audio_key,
text_key=self.args.text_key,
)
return CutSet.mux(
librispeech_clean_100_cuts,
librispeech_clean_360_cuts,
librispeech_other_cuts,
weights=[
28539,
104014,
148688,
],
)
@lru_cache()
def train_cuts_gigaspeech(self) -> CutSet:
logging.info("About to get train cuts")
gigaspeech_path = "fixie-ai/gigaspeech"
gigaspeech = load_dataset(
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
)
gigaspeech_cuts = CutSet.from_huggingface_dataset(
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
)
return gigaspeech_cuts

View File

@@ -10,3 +10,4 @@ transformers>=4.37.0
flash-attn
peft
torchmetrics
triton==3.3.0 # may conflict with openai-whisper

View File

@@ -68,24 +68,26 @@ from transformers import (
Qwen2Config,
Qwen2ForCausalLM,
)
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-# from icefall import diagnostics
-from utils import get_rank, get_world_size
# from icefall.env import get_env_info
+# from icefall import diagnostics
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
+get_rank,
+get_world_size,
setup_logger,
str2bool,
)
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
-torch.multiprocessing.set_start_method('spawn')
+torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
@@ -272,7 +274,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
-"valid_interval": 5000,
+"valid_interval": 3000,
# "env_info": get_env_info(),
}
)
@@ -332,6 +334,21 @@ def process_batch_vocalnet(batch: dict):
return messages, answer_cosyvoice_speech_token
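# Speech continuation batches: the user turn carries only an instruction plus the speech
# placeholder token, and the assistant target is the reference continuation text found in
# supervisions["text"] (taken from the --text-key column, e.g. "continuation"). No CosyVoice
# speech tokens are produced for this format.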
def process_batch_speech_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
# transcript = batch["supervisions"]["cut"][i].custom["text"]
messages.append(message)
return messages
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@@ -429,13 +446,13 @@ def compute_loss(
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
-batch_idx_train = params.batch_idx_train
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
if params.dataset_format == "slam_omni":
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
+elif params.dataset_format == "speech_continuation":
+messages = process_batch_speech_continuation(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@@ -566,8 +583,11 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
-model.encoder_projector.train()
+# model.encoder_projector.train()
model.train()
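# Keep frozen submodules in eval mode: the Whisper encoder is always frozen here, and the LLM
# as well unless --unfreeze-llm is set.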
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
@@ -583,6 +603,9 @@
world_size=world_size,
)
model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info( logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -594,7 +617,7 @@
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
-tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+tag=f"zero-checkpoint-{params.batch_idx_train}",
client_state={},
exclude_frozen_parameters=True,
)
@@ -602,18 +625,18 @@
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
-f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+tag=f"zero-checkpoint-{params.batch_idx_train}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
-f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
)
os.system(
-f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
@@ -687,9 +710,9 @@ def run(rank, world_size, args):
fix_random_seed(params.seed)
-if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
@@ -698,7 +721,6 @@
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
-speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
@@ -721,7 +743,7 @@
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
-llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
@@ -809,6 +831,9 @@ def run(rank, world_size, args):
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
# set params.batch_idx_train according to the checkpoint name
if "checkpoint-" in params.pretrained_model_path:
params.batch_idx_train = int(params.pretrained_model_path.split("checkpoint-")[-1].split("/")[0])
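# e.g. exp_dir/checkpoint-4000/pytorch_model.bin resumes the step counter at 4000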
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -842,11 +867,12 @@
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
-if c.duration < 1.0 or c.duration > 30.0:
-# logging.warning(
-#     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-# )
+if c.duration < 1.0 or c.duration > 29.5:
+logging.warning(
+f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+)
return False
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
codec_len = (
len(c.custom["answer_cosyvoice_speech_token"])
if "answer_cosyvoice_speech_token" in c.custom
@@ -865,6 +891,11 @@
elif params.dataset_format == "vocalnet":
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
elif params.dataset_format == "speech_continuation":
# train_cuts = data_module.train_cuts_ultravox()
# train_cuts = data_module.train_cuts_gigaspeech()
train_cuts = data_module.train_cuts_librispeech()
valid_cuts = data_module.valid_cuts_ultravox()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@@ -879,7 +910,7 @@
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
# train_dl = data_module.valid_dataloaders(train_cuts)
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
@@ -913,25 +944,25 @@
model.save_checkpoint(
save_dir=params.exp_dir,
-tag=f"epoch-{params.cur_epoch}",
+tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
-f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-tag=f"epoch-{params.cur_epoch}",
+f"{params.exp_dir}/epoch-{params.cur_epoch}",
+tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
-f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
+f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
)
-os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
@@ -971,6 +1002,7 @@ def main():
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
+warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)