mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
add training stage
This commit is contained in:
parent
e6897b10fa
commit
6b69276b19
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
|
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=2
|
stage=1
|
||||||
stop_stage=2
|
stop_stage=1
|
||||||
# All files generated by this script are saved in "data".
|
# All files generated by this script are saved in "data".
|
||||||
# You can safely remove "data" and rerun this script to regenerate it.
|
# You can safely remove "data" and rerun this script to regenerate it.
|
||||||
mkdir -p data
|
mkdir -p data
|
||||||
@ -20,8 +20,10 @@ log() {
|
|||||||
|
|
||||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
log "stage 0: "
|
log "stage 0: "
|
||||||
|
cd /workspace/slam/lhotse
|
||||||
|
git config --global --add safe.directory /workspace/slam/lhotse
|
||||||
|
pip install -e '.[dev]'
|
||||||
|
cd -
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
@ -43,3 +45,18 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
--use-lora False # --on-the-fly-feats True
|
--use-lora False # --on-the-fly-feats True
|
||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
ngpu=2
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
log "stage 3: "
|
||||||
|
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
|
||||||
|
--max-duration 200 \
|
||||||
|
--exp-dir ./slam_omni/exp_test \
|
||||||
|
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||||
|
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||||
|
--manifest-dir data/fbank \
|
||||||
|
--deepspeed \
|
||||||
|
--deepspeed_config ./slam_omni/ds_config_zero1.json \
|
||||||
|
--use-flash-attn True \
|
||||||
|
--use-lora True --unfreeze-llm True
|
||||||
|
fi
|
1
egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
Symbolic link
1
egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../ASR_LLM/whisper_llm_zh/ds_config_zero1.json
|
@ -395,7 +395,12 @@ def compute_loss(
|
|||||||
feature = feature.transpose(1, 2) # (N, C, T)
|
feature = feature.transpose(1, 2) # (N, C, T)
|
||||||
|
|
||||||
batch_idx_train = params.batch_idx_train
|
batch_idx_train = params.batch_idx_train
|
||||||
supervisions = batch["supervisions"]
|
|
||||||
|
answers = batch["supervisions"]["text"]
|
||||||
|
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
|
||||||
|
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
|
||||||
|
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
# remove spaces in texts
|
# remove spaces in texts
|
||||||
texts = [normalize_text_alimeeting(text) for text in texts]
|
texts = [normalize_text_alimeeting(text) for text in texts]
|
||||||
@ -426,7 +431,7 @@ def compute_loss(
|
|||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
feature_lens = supervisions["num_frames"]
|
feature_lens = batch["supervisions"]["num_frames"]
|
||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
@ -848,7 +853,6 @@ def display_and_save_batch(
|
|||||||
logging.info(f"Saving batch to {filename}")
|
logging.info(f"Saving batch to {filename}")
|
||||||
torch.save(batch, filename)
|
torch.save(batch, filename)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
|
||||||
features = batch["inputs"]
|
features = batch["inputs"]
|
||||||
|
|
||||||
logging.info(f"features shape: {features.shape}")
|
logging.info(f"features shape: {features.shape}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user