mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
add training stage
This commit is contained in:
parent
e6897b10fa
commit
6b69276b19
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
|
||||
set -eou pipefail
|
||||
|
||||
stage=2
|
||||
stop_stage=2
|
||||
stage=1
|
||||
stop_stage=1
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
@ -20,8 +20,10 @@ log() {
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: "
|
||||
|
||||
|
||||
cd /workspace/slam/lhotse
|
||||
git config --global --add safe.directory /workspace/slam/lhotse
|
||||
pip install -e '.[dev]'
|
||||
cd -
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
@ -43,3 +45,18 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--use-lora False # --on-the-fly-feats True
|
||||
|
||||
fi
|
||||
|
||||
ngpu=2
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: "
|
||||
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir ./slam_omni/exp_test \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./slam_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True
|
||||
fi
|
1
egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
Symbolic link
1
egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
|
@ -395,7 +395,12 @@ def compute_loss(
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
|
||||
batch_idx_train = params.batch_idx_train
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
answers = batch["supervisions"]["text"]
|
||||
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
|
||||
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
|
||||
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
# remove spaces in texts
|
||||
texts = [normalize_text_alimeeting(text) for text in texts]
|
||||
@ -426,7 +431,7 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
feature_lens = supervisions["num_frames"]
|
||||
feature_lens = batch["supervisions"]["num_frames"]
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
@ -848,7 +853,6 @@ def display_and_save_batch(
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user