mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Fix speech_llm recipe (#1936)
* Fix training/decoding scripts, clean up unused code, and ensure compliance with style checks

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
parent cc2e64a6aa
commit cd7caf12df
@@ -55,7 +55,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
   --deepspeed \
   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora False --unfreeze-llm False
+  --use-lora False \
+  --unfreeze-llm False
 
 # Then, we jointly train the projector and LLM LoRA modules.
 torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
   --deepspeed \
   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True \
+  --unfreeze-llm True \
   --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
 ```
 
@@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
   --epoch 999 --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --use-lora True \
+  --dataset aishell
 ```
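
Note: the README change above only reflows `--use-lora`/`--unfreeze-llm` onto separate lines; the two training stages still differ in which modules are trainable. A minimal sketch (not part of this commit) of the freezing logic those flags drive, assuming a `SPEECH_LLM`-style wrapper with `encoder`, `llm`, and `encoder_projector` submodules as in this recipe:

```python
# Sketch only: what --unfreeze-llm controls, under the assumptions above.
def set_trainable(model, unfreeze_llm: bool) -> None:
    for p in model.encoder.parameters():  # Whisper encoder: always frozen
        p.requires_grad = False
    for p in model.llm.parameters():  # LLM: trainable only in stage 2
        p.requires_grad = unfreeze_llm
    for p in model.encoder_projector.parameters():  # projector: always trained
        p.requires_grad = True

# Stage 1: --use-lora False --unfreeze-llm False  ->  set_trainable(model, False)
# Stage 2: --use-lora True  --unfreeze-llm True   ->  set_trainable(model, True)
```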
@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -357,43 +357,6 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
-
-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-        return text
-
     results = []
 
     num_cuts = 0
@@ -406,6 +369,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list("".join(text.split())) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
@@ -418,12 +382,8 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
-                ref_words = ref_text.split()
-                print(f"ref: {ref_text}")
-                print(f"hyp: {''.join(hyp_words)}")
-                this_batch.append((cut_id, ref_words, hyp_words))
+            for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_text))
 
             results[lm_scale].extend(this_batch)
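
Note: the loop above now scores at the character level; each reference has its spaces removed and is split into a list of characters before being paired with the hypothesis. A small illustration of that preprocessing (the example string is made up):

```python
# What `texts = [list("".join(text.split())) for text in texts]` does to one entry.
text = "甚至 出现 交易 几乎 停滞 的 情况"
chars = list("".join(text.split()))  # drop spaces, then split into characters
print(chars[:4])  # ['甚', '至', '出', '现']
```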
@@ -439,40 +399,38 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-
-    enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
+        logging.info(f"The transcripts are stored in {recog_path}")
 
-        # The following prints out WERs, per-word error statistics and aligned
+        # The following prints out CERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer
 
-        if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
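
Note: `save_results` now delegates character-level scoring to `write_error_stats(..., compute_CER=True)` instead of manually exploding words into characters. For reference, a sketch of the usual edit-distance CER definition (an assumption about the metric, not a copy of icefall's implementation, which also prints alignments and per-token statistics):

```python
# Sketch: standard CER from edit-distance counts against the reference characters.
def cer(substitutions: int, insertions: int, deletions: int, num_ref_chars: int) -> float:
    return (substitutions + insertions + deletions) / num_ref_chars

print(f"{cer(3, 1, 2, 100):.2%}")  # 6.00%
```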
@@ -495,9 +453,13 @@ def main():
 
     params = get_params()
     params.update(vars(args))
 
+    params.res_dir = params.exp_dir / f"{params.method}"
+
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        params.res_dir
+        / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
     )
 
     logging.info("Decoding started")
@@ -574,23 +536,20 @@ def main():
     if params.avg > 1:
         start = params.epoch - params.avg + 1
         assert start >= 1, start
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        assert "model" not in checkpoint
         # deepspeed converted checkpoint only contains model state_dict
         filenames = [
-            f"{params.exp_dir}/epoch-{epoch}.pt"
+            f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
             for epoch in range(start, params.epoch + 1)
         ]
         avg_checkpoint = average_checkpoints(filenames)
         model.load_state_dict(avg_checkpoint, strict=False)
 
-        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(avg_checkpoint, filename)
+        # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+        # torch.save(avg_checkpoint, filename)
     else:
         checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
+            map_location="cpu",
         )
         model.load_state_dict(checkpoint, strict=False)
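
Note: decoding now reads the DeepSpeed-converted per-epoch checkpoints at `epoch-N/pytorch_model.bin`. Conceptually, averaging such checkpoints is an element-wise mean over the saved state dicts; a rough sketch under that assumption (icefall's `average_checkpoints` is the authoritative implementation):

```python
import torch

# Sketch of checkpoint averaging, assuming all state dicts share keys and shapes.
# Integer entries (e.g. counters) are summed, not averaged, in this simplification.
def average_state_dicts(filenames):
    avg = torch.load(filenames[0], map_location="cpu")
    for name in filenames[1:]:
        state = torch.load(name, map_location="cpu")
        for k in avg:
            avg[k] = avg[k] + state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] = avg[k] / len(filenames)
    return avg
```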
@@ -643,8 +602,7 @@ def main():
     logging.info("Done!")
 
 
+if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-
-if __name__ == "__main__":
     main()
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
 #              2024  Yuekai Zhang
+#              2025  Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """
 
 import argparse
-import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
 
 import deepspeed
-import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    filter_uneven_sized_batch,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 DEFAULT_SPEECH_TOKEN = "<speech>"
 
@@ -286,13 +272,6 @@ def compute_loss(
     Returns:
       Return a tuple of two elements. The first element is the loss tensor.
     """
-    # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
-    # we simply drop the last few shortest samples, so that the retained total frames
-    # (after padding) would not exceed `allowed_max_frames`:
-    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
-    # We set allowed_excess_duration_ratio=0.1.
 
     def preprocess(
         messages,
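
Note: the removed comment encoded a concrete frame budget. For the record, a worked instance of that arithmetic (the duration value is illustrative, not a recipe default):

```python
# Worked example of the removed filter's frame budget (values illustrative).
max_duration = 200  # seconds of audio per batch
frame_shift_ms = 10
max_frames = max_duration * 1000 // frame_shift_ms  # 20000 frames
allowed_max_frames = int(max_frames * (1.0 + 0.1))  # 22000 with 10% excess allowed
```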
@@ -347,46 +326,6 @@ def compute_loss(
 
         return input_ids, attention_mask, target_ids
 
-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-        return text
-
-    max_frames = params.max_duration * 1000 // params.frame_shift_ms
-    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
 
     device = next(model.parameters()).device
     feature = batch["inputs"]
 
@@ -397,11 +336,10 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [normalize_text_alimeeting(text) for text in texts]
 
     messages = []
     for i, text in enumerate(texts):
+        text = text.replace(" ", "")
         message = [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
             {"role": "assistant", "content": text},
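
Note: for readers who don't read Chinese, the user prompt `请转写音频为文字` means "please transcribe the audio into text". With the in-loop `text.replace(" ", "")`, each training sample becomes a two-turn chat like this sketch (the transcript is illustrative):

```python
DEFAULT_SPEECH_TOKEN = "<speech>"
text = "甚至出现交易几乎停滞的情况"  # illustrative target, spaces already removed

# User turn: speech placeholder plus the instruction
# "please transcribe the audio into text"; assistant turn: the target text.
message = [
    {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
    {"role": "assistant", "content": text},
]
```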
@@ -516,14 +454,17 @@ def train_one_epoch(
       The rank of the node in DDP training. If no DDP is used, it should
       be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
 
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -533,6 +474,9 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +592,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
@@ -671,7 +614,6 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
@@ -728,7 +670,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)
 
-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
@@ -764,7 +706,7 @@ def run(rank, world_size, args):
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
         sampler_state_dict["max_duration"] = params.max_duration
-    # TODO: load sampler state dict
+
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
@@ -806,15 +748,15 @@ def run(rank, world_size, args):
 
         model.save_checkpoint(
             save_dir=params.exp_dir,
-            tag=f"epoch-{params.cur_epoch}",
+            tag=f"zero-epoch-{params.cur_epoch}",
             client_state={},
             exclude_frozen_parameters=True,
         )
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
-                f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}",
+                f"{params.exp_dir}/epoch-{params.cur_epoch}",
+                tag=f"zero-epoch-{params.cur_epoch}",
                 exclude_frozen_parameters=True,
             )
             # save sampler state dict into checkpoint
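
Note: taken together with the decode-side change above, the flow this hunk implies (a sketch inferred from the diff, not documented output): ZeRO shards are written under a `zero-epoch-N` tag, rank 0 consolidates them into an fp32 checkpoint directory, and the raw shards are then deleted.

```python
# Paths implied by this diff for epoch 3; exp_dir value is illustrative.
exp_dir = "whisper_llm_zh/exp_test"
cur_epoch = 3

zero_dir = f"{exp_dir}/zero-epoch-{cur_epoch}"  # raw ZeRO shards (temporary)
fp32_dir = f"{exp_dir}/epoch-{cur_epoch}"  # consolidated fp32 checkpoint dir
model_bin = f"{fp32_dir}/pytorch_model.bin"  # what decode.py loads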
@@ -824,7 +766,7 @@ def run(rank, world_size, args):
             f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
         )
 
-        os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+        os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
 
     logging.info("Done!")
 
@@ -865,6 +807,7 @@ def main():
 
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)
 