s2t training

This commit is contained in:
root 2025-04-15 02:16:03 +00:00
parent 1d11662016
commit 3ad075af60
3 changed files with 70 additions and 55 deletions

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
set -eou pipefail
stage=1
stop_stage=1
stage=$1
stop_stage=$2
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
@ -20,10 +20,12 @@ log() {
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: "
pip uninstall -y lhotse
cd /workspace/slam/lhotse
git config --global --add safe.directory /workspace/slam/lhotse
pip install -e '.[dev]'
cd -
pip install -r slam_omni/requirements.txt
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@ -32,6 +34,20 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
python3 local/compute_whisper_fbank.py
fi
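Stage 1 computes Whisper-compatible filterbanks. A condensed sketch of what local/compute_whisper_fbank.py presumably does with lhotse (the input manifest path is hypothetical; 80 mel bins match the validation loader changed below):

from lhotse import CutSet, WhisperFbank, WhisperFbankConfig

cuts = CutSet.from_file("data/manifests/belle_cuts_raw.jsonl.gz")  # hypothetical input
cuts = cuts.compute_and_store_features(
    extractor=WhisperFbank(WhisperFbankConfig(num_filters=80)),
    storage_path="data/fbank/feats_belle",
)
cuts.to_file("data/fbank/cuts_belle.jsonl.gz")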
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 3: Combine features"
manifest_dir=data/fbank
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
# # remove cuts_belle.00000.jsonl.gz (the test shard) from pieces
# pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
echo $pieces | wc  # sanity check: word count = number of shards
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
fi
fi
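For reference, the same combine step can be driven from Python instead of the CLI (a minimal sketch using lhotse's combine and load_manifest_lazy; paths as in the script above):

from pathlib import Path
from lhotse import combine, load_manifest_lazy

manifest_dir = Path("data/fbank")
# collect the per-shard manifests in sorted order, mirroring find | sort above
pieces = sorted(manifest_dir.glob("cuts_belle.*.jsonl.gz"))
cuts = combine(*[load_manifest_lazy(p) for p in pieces])
cuts.to_file(manifest_dir / "cuts_belle_00001-01600.jsonl.gz")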
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: "
python3 ./slam_omni/decode.py \
@ -46,17 +62,21 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
ngpu=2
ngpu=8
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: "
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--max-duration 200 \
--exp-dir ./slam_omni/exp_test \
--max-duration 80 \
--enable-musan False \
--exp-dir ./slam_omni/exp_speech2text \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./slam_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
--sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
--use-lora True --unfreeze-llm True
fi
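The new --pretrained-model-path and --sampler-state-dict-path flags resume a mid-epoch run. A minimal sketch of what that resumption looks like on the Python side (model and train_sampler are assumed to already exist; lhotse samplers expose load_state_dict):

import torch

# load the DeepSpeed-converted weights into the model
state_dict = torch.load(
    "slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin",
    map_location="cpu",
)
model.load_state_dict(state_dict, strict=False)

# restore the sampler's position so the epoch continues where it left off
sampler_state = torch.load("slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt")
train_sampler.load_state_dict(sampler_state)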

View File

@ -357,26 +357,20 @@ class AsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
# if self.args.concatenate_cuts:
# transforms = [
# CutConcatenate(
# duration_factor=self.args.duration_factor, gap=self.args.gap
# )
# ] + transforms
"""
Args:
cuts_valid:
CutSet for validation.
"""
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
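# Build the validation dataset in one shot: compute 80-dim Whisper fbank on the
# fly when requested, otherwise instantiate the input strategy named on the CLI.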
validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
@ -434,4 +428,18 @@ class AsrDataModule:
cut_set = cut_set.resample(16000)
return {'test':cut_set}
else:
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
pass
else:
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
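A short usage sketch of the new accessors (mirroring how train.py consumes them below; args holds the parsed data-module options):

data_module = AsrDataModule(args)
train_cuts = data_module.train_cuts()   # cuts_belle_train.jsonl.gz, loaded lazily
train_dl = data_module.train_dataloaders(train_cuts)
valid_dl = data_module.valid_dataloaders(data_module.dev_cuts())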

View File

@ -206,13 +206,6 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-aishell",
type=str2bool,
default=True,
help="Whether to only use aishell1 dataset for training.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
@ -297,13 +290,11 @@ def compute_loss(
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
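# Rendered example for a single exchange (content placeholders are illustrative):
#   <|im_start|>user\n<speech><|im_end|>\n<|im_start|>assistant\nANSWER<|im_end|>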
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
@ -311,11 +302,16 @@ def compute_loss(
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
max_length=max_len,
truncation=True,
truncation=False,
)
)
# pad texts to the same length with tokenizer.pad_token_id (texts is a list of token-id lists)
# drop texts that are too long to fit the context
texts = [text for text in texts if len(text) < 1024]
if len(texts) != len(messages):
logging.warning(
f"Removed overlong text(s) from batch: {messages}"
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
@ -336,18 +332,14 @@ def compute_loss(
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
default_speech_token_indices = torch.where(
mask_indices = torch.where(
input_ids == default_speech_token_id
)
mask_indices = torch.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
# + 6 to skip: '<speech>', '<|im_end|>', '\n', '<|im_start|>', 'assistant', '\n' (ids 151665, 151645, 198, 151644, 77091, 198)
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
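A toy illustration of the new masking rule (token ids other than the special ids quoted in the comment above are made up):

import torch

IGNORE_TOKEN_ID = -100
SPEECH_TOKEN_ID = 151665  # id of DEFAULT_SPEECH_TOKEN, per the comment above

# <|im_start|> user \n <speech> <|im_end|> \n <|im_start|> assistant \n  answer tokens
input_ids = torch.tensor([[151644, 872, 198, 151665, 151645, 198, 151644, 77091, 198, 42, 43]])
target_ids = input_ids.clone()
rows, cols = torch.where(input_ids == SPEECH_TOKEN_ID)
for row, col in zip(rows.tolist(), cols.tolist()):
    # + 6 skips <speech>, <|im_end|>, '\n', <|im_start|>, 'assistant', '\n'
    target_ids[row, : col + 6] = IGNORE_TOKEN_ID
# only the answer tokens (42, 43) are left as supervision targets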
@ -380,6 +372,7 @@ def compute_loss(
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split('USER:')
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# e.g. "USER: Write a poem about summer. ASSISTANT: Blazing summer days, everything is growing; enjoying the lovely summer sunshine. USER: List some news headlines for me. ASSISTANT: The news in today's society never stops."
question_answer = history_question_answer[j].split('ASSISTANT:')
@ -393,7 +386,7 @@ def compute_loss(
]
messages.append(message)
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
@ -508,7 +501,7 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -720,7 +713,6 @@ def run(rank, world_size, args):
)
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -738,10 +730,8 @@ def run(rank, world_size, args):
return False
return True
if params.use_aishell:
train_cuts = multi_dataset.aishell_train_cuts()
else:
train_cuts = multi_dataset.train_cuts()
train_cuts = data_module.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@ -754,10 +744,7 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.use_aishell:
valid_cuts = multi_dataset.aishell_dev_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_cuts = data_module.dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0: