diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index b61241974..6c7393379 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
 set -eou pipefail
 
-stage=2
-stop_stage=2
+stage=1
+stop_stage=1
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
@@ -20,8 +20,10 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
-
-
+  cd /workspace/slam/lhotse
+  git config --global --add safe.directory /workspace/slam/lhotse
+  pip install -e '.[dev]'
+  cd -
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -43,3 +45,18 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     --use-lora False
     # --on-the-fly-feats True
 fi
+
+ngpu=2
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "stage 3: "
+  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
+    --max-duration 200 \
+    --exp-dir ./slam_omni/exp_test \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --manifest-dir data/fbank \
+    --deepspeed \
+    --deepspeed_config ./slam_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --use-lora True --unfreeze-llm True
+fi
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json b/egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
new file mode 120000
index 000000000..4fbacea32
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
@@ -0,0 +1 @@
+../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 1c3ccd2c6..f05e5c1ac 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -395,7 +395,12 @@ def compute_loss(
     feature = feature.transpose(1, 2)  # (N, C, T)
 
     batch_idx_train = params.batch_idx_train
-    supervisions = batch["supervisions"]
+
+    answers = batch["supervisions"]["text"]
+    questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
+    answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
+    last_questions = [question.split(': ')[-1].strip() for question in questions_with_history]
+
     texts = batch["supervisions"]["text"]
     # remove spaces in texts
     texts = [normalize_text_alimeeting(text) for text in texts]
@@ -426,7 +431,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        feature_lens = supervisions["num_frames"]
+        feature_lens = batch["supervisions"]["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
@@ -848,7 +853,6 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)
 
-    supervisions = batch["supervisions"]
     features = batch["inputs"]
 
     logging.info(f"features shape: {features.shape}")
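
Note on the compute_loss() change above: instead of reading a plain supervisions dict, the new code unpacks per-cut metadata from the lhotse cuts attached to the batch. Below is a minimal, self-contained sketch of that unpacking; the _FakeCut class, the example batch, and the "ROLE: utterance" history format are illustrative assumptions and are not part of the patch.

# Sketch of the field unpacking done in compute_loss() above.
# _FakeCut and the example batch are hypothetical stand-ins; in training, the
# cuts come from the manifests prepared in earlier stages and carry
# "question" (dialogue history) and "answer_cosyvoice_speech_token" in .custom.


class _FakeCut:
    def __init__(self, custom):
        self.custom = custom


batch = {
    "supervisions": {
        "text": ["The weather is nice today."],
        "cut": [
            _FakeCut(
                {
                    # Assumed history format: turns joined as "ROLE: utterance".
                    "question": "USER: hello ASSISTANT: hi USER: how is the weather",
                    "answer_cosyvoice_speech_token": [12, 345, 678],
                }
            )
        ],
    }
}

answers = batch["supervisions"]["text"]
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [
    cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
]
# Everything after the last ": " separator is treated as the current question.
last_questions = [q.split(": ")[-1].strip() for q in questions_with_history]

print(last_questions)  # ['how is the weather']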