From 211c01bc1dcb3a403aac69838eca8b2b480b03e1 Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Wed, 7 May 2025 12:37:19 +0000 Subject: [PATCH] format train.py minor fix train.py --- egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 7947a60a5..6e43bf83f 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -18,18 +18,6 @@ # limitations under the License. """ Usage: -# fine-tuning with whisper and Qwen2 -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct - torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ --exp-dir ./whisper_llm_zh/exp_test \ @@ -39,7 +27,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora True --unfreeze-llm True + --use-lora True \ + --unfreeze-llm True """ import argparse @@ -333,7 +322,6 @@ def compute_loss( feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) - batch_idx_train = params.batch_idx_train supervisions = batch["supervisions"] texts = batch["supervisions"]["text"] @@ -378,7 +366,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, + tokenizer: AutoTokenizer, model: nn.Module, valid_dl: torch.utils.data.DataLoader, world_size: int = 1,