format train.py

minor fix train.py
This commit is contained in:
Yifan Yang 2025-05-07 12:37:19 +00:00
parent 23b5a7ce3e
commit 211c01bc1d

View File

@ -18,18 +18,6 @@
# limitations under the License. # limitations under the License.
""" """
Usage: Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \ --max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \ --exp-dir ./whisper_llm_zh/exp_test \
@ -39,7 +27,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --unfreeze-llm True --use-lora True \
--unfreeze-llm True
""" """
import argparse import argparse
@ -333,7 +322,6 @@ def compute_loss(
feature = feature.to(device) feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T) feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
@ -378,7 +366,7 @@ def compute_loss(
def compute_validation_loss( def compute_validation_loss(
params: AttributeDict, params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer, tokenizer: AutoTokenizer,
model: nn.Module, model: nn.Module,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,