add training stage

root 2025-04-11 06:50:25 +00:00
parent e6897b10fa
commit 6b69276b19
3 changed files with 29 additions and 7 deletions


@@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
 set -eou pipefail
-stage=2
-stop_stage=2
+stage=1
+stop_stage=1
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
@@ -20,8 +20,10 @@ log() {
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
 log "stage 0: "
+cd /workspace/slam/lhotse
+git config --global --add safe.directory /workspace/slam/lhotse
+pip install -e '.[dev]'
+cd -
 fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -43,3 +45,18 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 --use-lora False # --on-the-fly-feats True
 fi
+ngpu=2
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+log "stage 3: "
+torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
+--max-duration 200 \
+--exp-dir ./slam_omni/exp_test \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+--manifest-dir data/fbank \
+--deepspeed \
+--deepspeed_config ./slam_omni/ds_config_zero1.json \
+--use-flash-attn True \
+--use-lora True --unfreeze-llm True
+fi
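
The new stage 3 launches distributed training with torchrun and DeepSpeed, passing --use-lora True and --unfreeze-llm True to ./slam_omni/train.py. How train.py applies LoRA is not shown in this diff; a minimal sketch of one common way to attach LoRA adapters to a Hugging Face causal LM with the peft library (hypothetical rank, alpha, and Qwen2-style target module names, not values taken from train.py):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical LoRA settings; the values actually used by ./slam_omni/train.py
# are not part of this diff.
llm = AutoModelForCausalLM.from_pretrained("models/Qwen2.5-0.5B-Instruct")
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # by default only the adapter weights are trainable

With --unfreeze-llm True the script presumably also leaves the base LLM weights trainable rather than training the adapters alone.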


@@ -0,0 +1 @@
+../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
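
The new ds_config_zero1.json is a symlink to the existing whisper_llm_zh DeepSpeed config, so its contents do not appear in this diff. For orientation only, a minimal ZeRO stage-1 configuration written out from Python might look like the following (assumed keys and values, not the actual contents of that file):

import json

# Assumed minimal ZeRO stage-1 settings; the real ds_config_zero1.json
# (symlinked from ASR_LLM/whisper_llm_zh) may differ.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "zero_optimization": {"stage": 1},
    "bf16": {"enabled": True},
}

with open("slam_omni/ds_config_zero1.json", "w") as f:
    json.dump(ds_config, f, indent=2)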


@@ -395,7 +395,12 @@ def compute_loss(
 feature = feature.transpose(1, 2)  # (N, C, T)
 batch_idx_train = params.batch_idx_train
+supervisions = batch["supervisions"]
+answers = batch["supervisions"]["text"]
+questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
+last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
 texts = batch["supervisions"]["text"]
 # remove spaces in texts
 texts = [normalize_text_alimeeting(text) for text in texts]
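
The added lines read the dialogue history, reference answers, and CosyVoice speech tokens from each cut's custom fields, then keep only the most recent user turn. A small sketch of how that split behaves, assuming a history string in which turns are tagged with '<USER>: ' (the '<ASSISTANT>: ' tag here is hypothetical; only '<USER>: ' appears in the diff):

# Assumed multi-turn history format, inferred from the split on '<USER>: '.
question_with_history = "<USER>: 今天天气怎么样 <ASSISTANT>: 今天是晴天。 <USER>: 那适合去跑步吗"

# Same expression as in compute_loss: keep only the latest user question.
last_question = question_with_history.split('<USER>: ')[-1].strip()
print(last_question)  # 那适合去跑步吗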
@@ -426,7 +431,7 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-feature_lens = supervisions["num_frames"]
+feature_lens = batch["supervisions"]["num_frames"]
 info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 # Note: We use reduction=sum while computing the loss.
@@ -848,7 +853,6 @@ def display_and_save_batch(
 logging.info(f"Saving batch to {filename}")
 torch.save(batch, filename)
-supervisions = batch["supervisions"]
 features = batch["inputs"]
 logging.info(f"features shape: {features.shape}")