mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
add speech continuation pretraining
This commit is contained in:
parent
e65725810c
commit
f81363d324
@ -122,7 +122,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "stage 1: Compute fbank feature from huggingface"
|
||||
log "stage 6: Compute fbank feature from huggingface"
|
||||
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
|
||||
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
# --out-dir data/fbank_voice_assistant \
|
||||
@ -161,10 +161,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
--subset test --split test \
|
||||
--audio-key audio --text-key text \
|
||||
--prefix gigaspeech
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "stage 9: Compute fbank feature from huggingface"
|
||||
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
|
||||
--out-dir data/fbank_gigaspeech \
|
||||
@ -195,7 +192,7 @@ fi
|
||||
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "stage 11: Decoding EN, only support batch_size=1 for now."
|
||||
log "stage 11: Decoding EN, val set only support batch_size=1 for now."
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
# cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
@ -256,3 +253,71 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
||||
--output-dir test_result
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
log "stage 15: Training Speech2Speech Model, adaptor only"
|
||||
exp_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=2
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 600 \
|
||||
--enable-musan False \
|
||||
--audio-key audio --text-key continuation \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--dataset-format speech_continuation \
|
||||
--start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
|
||||
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
|
||||
fi
|
||||
|
||||
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
|
||||
log "stage 16: Training Speech2Speech Model, adaptor only"
|
||||
exp_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
train_cmd_args="--max-duration 1200 \
|
||||
--enable-musan False \
|
||||
--audio-key audio --text-key continuation \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--on-the-fly-feats True \
|
||||
--deepspeed \
|
||||
--huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--dataset-format speech_continuation \
|
||||
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
|
@ -24,7 +24,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import interleave_datasets, load_dataset
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PerturbSpeed,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
@ -47,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
from lhotse.utils import fix_random_seed
|
||||
from speech_dataset import K2SpeechRecognitionDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from utils import str2bool
|
||||
from utils import get_rank, str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
@ -123,6 +124,14 @@ class AsrDataModule:
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-speed-perturb",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use on-the-fly speed perturbation. "
|
||||
"Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
@ -188,27 +197,27 @@ class AsrDataModule:
|
||||
group.add_argument(
|
||||
"--huggingface-dataset-path-or-name",
|
||||
type=str,
|
||||
default="/workspace/Belle_1.4M-SLAM-Omni",
|
||||
default=None,
|
||||
help="The path or name of the Huggingface dataset",
|
||||
)
|
||||
group.add_argument(
|
||||
"--audio-key",
|
||||
type=str,
|
||||
default="question_audio",
|
||||
default="audio",
|
||||
help="The key in the Huggingface dataset containing the audio data",
|
||||
)
|
||||
group.add_argument(
|
||||
"--text-key",
|
||||
type=str,
|
||||
default="answer",
|
||||
default="text",
|
||||
help="The key in the Huggingface dataset containing the text data",
|
||||
)
|
||||
group.add_argument(
|
||||
"--resample-to-16kHz",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Resample audio to 16kHz. Default: False.",
|
||||
)
|
||||
# group.add_argument(
|
||||
# "--resample-to-16kHz",
|
||||
# type=str2bool,
|
||||
# default=True,
|
||||
# help="Resample audio to 16kHz. Default: False.",
|
||||
# )
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
@ -232,6 +241,8 @@ class AsrDataModule:
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
|
||||
transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
@ -260,9 +271,11 @@ class AsrDataModule:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
rank = get_rank()
|
||||
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
@ -271,26 +284,6 @@ class AsrDataModule:
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
# if self.args.on_the_fly_feats:
|
||||
# # NOTE: the PerturbSpeed transform should be added only if we
|
||||
# # remove it from data prep stage.
|
||||
# # Add on-the-fly speed perturbation; since originally it would
|
||||
# # have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# # 3x more epochs.
|
||||
# # Speed perturbation probably should come first before
|
||||
# # concatenation, but in principle the transforms order doesn't have
|
||||
# # to be strict (e.g. could be randomized)
|
||||
# # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# # Drop feats to be on the safe side.
|
||||
# train = K2SpeechRecognitionDataset(
|
||||
# cut_transforms=transforms,
|
||||
# input_strategy=OnTheFlyFeatures(
|
||||
# WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
||||
# ),
|
||||
# input_transforms=input_transforms,
|
||||
# return_cuts=self.args.return_cuts,
|
||||
# )
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
@ -298,8 +291,7 @@ class AsrDataModule:
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
buffer_size=self.args.num_buckets * 1000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
@ -339,10 +331,10 @@ class AsrDataModule:
|
||||
CutSet for validation.
|
||||
"""
|
||||
logging.info("About to create dev dataset")
|
||||
|
||||
rank = get_rank()
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
@ -470,25 +462,231 @@ class AsrDataModule:
|
||||
)
|
||||
return {"test": VoiceAssistant_cuts}
|
||||
|
||||
# def train_cuts_en_vocalnet(self) -> CutSet:
|
||||
@lru_cache()
|
||||
def train_cuts_ultravox(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if self.args.huggingface_dataset_path_or_name is not None:
|
||||
librispeech_path = (
|
||||
self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
|
||||
)
|
||||
people_speech_path = (
|
||||
self.args.huggingface_dataset_path_or_name + "/peoples_speech"
|
||||
)
|
||||
gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
|
||||
else:
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
people_speech_path = "fixie-ai/peoples_speech"
|
||||
gigaspeech_path = "fixie-ai/gigaspeech"
|
||||
# 148_688
|
||||
librispeech_other = load_dataset(
|
||||
librispeech_path, "other", split="train.500", streaming=True
|
||||
)
|
||||
# 104_014
|
||||
librispeech_clean_360 = load_dataset(
|
||||
librispeech_path, "clean", split="train.360", streaming=True
|
||||
)
|
||||
# 28_539
|
||||
librispeech_clean_100 = load_dataset(
|
||||
librispeech_path, "clean", split="train.100", streaming=True
|
||||
)
|
||||
|
||||
# 1_501_271
|
||||
people_speech_clean = load_dataset(
|
||||
people_speech_path, "clean", split="train", streaming=True
|
||||
)
|
||||
# 548_000
|
||||
people_speech_dirty_sa = load_dataset(
|
||||
people_speech_path, "dirty_sa", split="train", streaming=True
|
||||
)
|
||||
|
||||
# 8_266_422
|
||||
|
||||
gigaspeech = load_dataset(
|
||||
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
|
||||
)
|
||||
|
||||
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_100,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
librispeech_other_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_other,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_360,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
gigaspeech_cuts = CutSet.from_huggingface_dataset(
|
||||
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
|
||||
)
|
||||
|
||||
people_speech_clean_cuts = CutSet.from_huggingface_dataset(
|
||||
people_speech_clean,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
|
||||
people_speech_dirty_sa,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
librispeech_clean_100_cuts,
|
||||
librispeech_clean_360_cuts,
|
||||
librispeech_other_cuts,
|
||||
gigaspeech_cuts,
|
||||
people_speech_clean_cuts,
|
||||
people_speech_dirty_sa_cuts,
|
||||
weights=[
|
||||
28539,
|
||||
104014,
|
||||
148688,
|
||||
8266422,
|
||||
1501271,
|
||||
548000,
|
||||
],
|
||||
)
|
||||
|
||||
# @lru_cache()
|
||||
# def train_cuts_ultravox(self) -> CutSet:
|
||||
# logging.info("About to get train cuts")
|
||||
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||
# )
|
||||
# return VoiceAssistant_cuts
|
||||
# keep_columns = ["audio", "text", "continuation", "id"]
|
||||
# librispeech_path="fixie-ai/librispeech_asr"
|
||||
# # 148_688
|
||||
# librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True)
|
||||
# # 104_014
|
||||
# librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True)
|
||||
# # 28_539
|
||||
# librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
|
||||
|
||||
# @lru_cache()
|
||||
# def valid_cuts_en_vocalnet(self) -> CutSet:
|
||||
# logging.info("About to get valid cuts")
|
||||
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||
# )
|
||||
# return VoiceAssistant_cuts
|
||||
# cols_to_remove = librispeech_clean_100.column_names
|
||||
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
|
||||
# librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
|
||||
# librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
|
||||
# librispeech_other = librispeech_other.remove_columns(cols_to_remove)
|
||||
# people_speech_path="fixie-ai/peoples_speech"
|
||||
# # 1_501_271
|
||||
# people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
|
||||
# # 548_000
|
||||
# people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
|
||||
# cols_to_remove = people_speech_clean.column_names
|
||||
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
|
||||
# people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
|
||||
# people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
|
||||
|
||||
# @lru_cache()
|
||||
# def test_cuts_en_vocalnet(self) -> CutSet:
|
||||
# logging.info("About to get test cuts")
|
||||
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||
# # 8_266_422
|
||||
# gigaspeech_path="fixie-ai/gigaspeech"
|
||||
# gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
|
||||
# # first rename segment_id to id
|
||||
# gigaspeech = gigaspeech.rename_column("segment_id", "id")
|
||||
# cols_to_remove = gigaspeech.column_names
|
||||
# cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
|
||||
# gigaspeech = gigaspeech.remove_columns(cols_to_remove)
|
||||
|
||||
# total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
|
||||
# final_datasets = interleave_datasets([
|
||||
# librispeech_clean_100,
|
||||
# librispeech_clean_360,
|
||||
# gigaspeech,
|
||||
# people_speech_clean,
|
||||
# people_speech_dirty_sa,
|
||||
# librispeech_other,
|
||||
# ], probabilities=[
|
||||
# 28539 / total_item,
|
||||
# 104014 / total_item,
|
||||
# 8266422 / total_item,
|
||||
# 1501271 / total_item,
|
||||
# 548000 / total_item,
|
||||
# 148688 / total_item,
|
||||
# ])
|
||||
|
||||
# train_cuts = CutSet.from_huggingface_dataset(
|
||||
# final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
|
||||
# )
|
||||
# return VoiceAssistant_cuts
|
||||
|
||||
# return train_cuts
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts_ultravox(self) -> CutSet:
|
||||
logging.info("About to get valid cuts")
|
||||
librispeech_path = "fixie-ai/librispeech_asr"
|
||||
librispeech_clean_valid = load_dataset(
|
||||
librispeech_path, "clean", split="validation", streaming=True
|
||||
)
|
||||
librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_valid,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
return librispeech_clean_valid_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_librispeech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
|
||||
# librispeech_path="fixie-ai/librispeech_asr"
|
||||
librispeech_path = "/workspace/slam/librispeech_asr"
|
||||
# 148_688
|
||||
librispeech_other = load_dataset(
|
||||
librispeech_path, "other", split="train.500", streaming=True
|
||||
)
|
||||
# 104_014
|
||||
librispeech_clean_360 = load_dataset(
|
||||
librispeech_path, "clean", split="train.360", streaming=True
|
||||
)
|
||||
# 28_539
|
||||
librispeech_clean_100 = load_dataset(
|
||||
librispeech_path, "clean", split="train.100", streaming=True
|
||||
)
|
||||
|
||||
librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_100,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
librispeech_other_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_other,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
|
||||
librispeech_clean_360,
|
||||
audio_key=self.args.audio_key,
|
||||
text_key=self.args.text_key,
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
librispeech_clean_100_cuts,
|
||||
librispeech_clean_360_cuts,
|
||||
librispeech_other_cuts,
|
||||
weights=[
|
||||
28539,
|
||||
104014,
|
||||
148688,
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_gigaspeech(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
gigaspeech_path = "fixie-ai/gigaspeech"
|
||||
gigaspeech = load_dataset(
|
||||
gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
|
||||
)
|
||||
|
||||
gigaspeech_cuts = CutSet.from_huggingface_dataset(
|
||||
gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
|
||||
)
|
||||
|
||||
return gigaspeech_cuts
|
||||
|
@ -10,3 +10,4 @@ transformers>=4.37.0
|
||||
flash-attn
|
||||
peft
|
||||
torchmetrics
|
||||
triton==3.3.0 # may be violate with openai-whisper
|
||||
|
@ -68,24 +68,26 @@ from transformers import (
|
||||
Qwen2Config,
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
# from icefall import diagnostics
|
||||
from utils import get_rank, get_world_size
|
||||
# from icefall.env import get_env_info
|
||||
# from icefall import diagnostics
|
||||
from utils import ( # filter_uneven_sized_batch,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||
try:
|
||||
torch.multiprocessing.set_start_method('spawn')
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "batch_count"):
|
||||
@ -272,7 +274,7 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 5000,
|
||||
"valid_interval": 3000,
|
||||
# "env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -332,6 +334,21 @@ def process_batch_vocalnet(batch: dict):
|
||||
return messages, answer_cosyvoice_speech_token
|
||||
|
||||
|
||||
def process_batch_speech_continuation(batch: dict):
|
||||
messages = []
|
||||
for i in range(len(batch["supervisions"]["text"])):
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
|
||||
},
|
||||
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
|
||||
]
|
||||
# transcript = batch["supervisions"]["cut"][i].custom["text"]
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
@ -429,13 +446,13 @@ def compute_loss(
|
||||
feature = feature.to(device)
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
|
||||
batch_idx_train = params.batch_idx_train
|
||||
|
||||
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
|
||||
if params.dataset_format == "slam_omni":
|
||||
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
|
||||
elif params.dataset_format == "vocalnet":
|
||||
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
|
||||
elif params.dataset_format == "speech_continuation":
|
||||
messages = process_batch_speech_continuation(batch)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
|
||||
|
||||
@ -566,8 +583,11 @@ def train_one_epoch(
|
||||
The rank of the node in DDP training. If no DDP is used, it should
|
||||
be set to 0.
|
||||
"""
|
||||
model.encoder_projector.train()
|
||||
|
||||
# model.encoder_projector.train()
|
||||
model.train()
|
||||
model.encoder.eval()
|
||||
if not params.unfreeze_llm:
|
||||
model.llm.eval()
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
@ -583,6 +603,9 @@ def train_one_epoch(
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
model.encoder.eval()
|
||||
if not params.unfreeze_llm:
|
||||
model.llm.eval()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
@ -594,7 +617,7 @@ def train_one_epoch(
|
||||
if batch_idx != 0:
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
tag=f"zero-checkpoint-{params.batch_idx_train}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
@ -602,18 +625,18 @@ def train_one_epoch(
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
|
||||
tag=f"zero-checkpoint-{params.batch_idx_train}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
|
||||
)
|
||||
try:
|
||||
with torch.amp.autocast("cuda", enabled=params.use_fp16):
|
||||
@ -687,9 +710,9 @@ def run(rank, world_size, args):
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
if rank == 0:
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
replace_whisper_encoder_forward()
|
||||
@ -698,7 +721,6 @@ def run(rank, world_size, args):
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
for name, param in speech_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
speech_encoder.eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
|
||||
@ -721,7 +743,7 @@ def run(rank, world_size, args):
|
||||
if not params.unfreeze_llm:
|
||||
for name, param in llm.named_parameters():
|
||||
param.requires_grad = False
|
||||
llm.eval()
|
||||
|
||||
else:
|
||||
if params.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
@ -809,6 +831,9 @@ def run(rank, world_size, args):
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
|
||||
# set params.batch_idx_train according to the checkpoint name
|
||||
if "checkpoint-" in params.pretrained_model_path:
|
||||
params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
@ -842,21 +867,22 @@ def run(rank, world_size, args):
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 30.0:
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
return False
|
||||
codec_len = (
|
||||
len(c.custom["answer_cosyvoice_speech_token"])
|
||||
if "answer_cosyvoice_speech_token" in c.custom
|
||||
else len(c.custom["speech_token"])
|
||||
)
|
||||
if codec_len > 2200:
|
||||
if c.duration < 1.0 or c.duration > 29.5:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
return False
|
||||
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
|
||||
codec_len = (
|
||||
len(c.custom["answer_cosyvoice_speech_token"])
|
||||
if "answer_cosyvoice_speech_token" in c.custom
|
||||
else len(c.custom["speech_token"])
|
||||
)
|
||||
if codec_len > 2200:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
if params.dataset_format == "slam_omni":
|
||||
@ -865,6 +891,11 @@ def run(rank, world_size, args):
|
||||
elif params.dataset_format == "vocalnet":
|
||||
train_cuts = data_module.train_cuts_en_vocalnet()
|
||||
valid_cuts = data_module.valid_cuts_en_vocalnet()
|
||||
elif params.dataset_format == "speech_continuation":
|
||||
# train_cuts = data_module.train_cuts_ultravox()
|
||||
# train_cuts = data_module.train_cuts_gigaspeech()
|
||||
train_cuts = data_module.train_cuts_librispeech()
|
||||
valid_cuts = data_module.valid_cuts_ultravox()
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
|
||||
|
||||
@ -879,7 +910,7 @@ def run(rank, world_size, args):
|
||||
train_dl = data_module.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
# train_dl = data_module.valid_dataloaders(train_cuts)
|
||||
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
@ -913,25 +944,25 @@ def run(rank, world_size, args):
|
||||
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}",
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
|
||||
)
|
||||
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
@ -971,6 +1002,7 @@ def main():
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
run(rank=rank, world_size=world_size, args=args)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user