mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
format train.py
minor fix train.py
This commit is contained in:
parent
23b5a7ce3e
commit
211c01bc1d
@ -18,18 +18,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
# fine-tuning with whisper and Qwen2
|
|
||||||
pip install "huggingface_hub[cli]"
|
|
||||||
mkdir -p models/whisper models/qwen
|
|
||||||
|
|
||||||
# For aishell fine-tuned whisper model
|
|
||||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
|
|
||||||
# For multi-hans fine-tuned whisper model
|
|
||||||
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
|
||||||
|
|
||||||
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
|
|
||||||
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
|
|
||||||
|
|
||||||
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
||||||
--max-duration 200 \
|
--max-duration 200 \
|
||||||
--exp-dir ./whisper_llm_zh/exp_test \
|
--exp-dir ./whisper_llm_zh/exp_test \
|
||||||
@ -39,7 +27,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
|||||||
--deepspeed \
|
--deepspeed \
|
||||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||||
--use-flash-attn True \
|
--use-flash-attn True \
|
||||||
--use-lora True --unfreeze-llm True
|
--use-lora True \
|
||||||
|
--unfreeze-llm True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -333,7 +322,6 @@ def compute_loss(
|
|||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
feature = feature.transpose(1, 2) # (N, C, T)
|
feature = feature.transpose(1, 2) # (N, C, T)
|
||||||
|
|
||||||
batch_idx_train = params.batch_idx_train
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
@ -378,7 +366,7 @@ def compute_loss(
|
|||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
tokenizer: whisper.tokenizer.Tokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user