mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
s2t training
This commit is contained in:
parent
1d11662016
commit
3ad075af60
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
|
||||
set -eou pipefail
|
||||
|
||||
stage=1
|
||||
stop_stage=1
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
@ -20,10 +20,12 @@ log() {
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: "
|
||||
pip uninstall lhotse
|
||||
cd /workspace/slam/lhotse
|
||||
git config --global --add safe.directory /workspace/slam/lhotse
|
||||
pip install -e '.[dev]'
|
||||
cd -
|
||||
pip install -r slam_omni/requirements.txt
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
@ -32,6 +34,20 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
python3 local/compute_whisper_fbank.py
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 3: Combine features"
|
||||
manifest_dir=data/fbank
|
||||
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
|
||||
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
|
||||
# # remove cust_belle_00000.jsonl.gz from pieces
|
||||
# pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
|
||||
echo $pieces | wc
|
||||
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
|
||||
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "stage 2: "
|
||||
python3 ./slam_omni/decode.py \
|
||||
@ -46,17 +62,21 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
|
||||
fi
|
||||
|
||||
ngpu=2
|
||||
ngpu=8
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: "
|
||||
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir ./slam_omni/exp_test \
|
||||
--max-duration 80 \
|
||||
--enable-musan False \
|
||||
--exp-dir ./slam_omni/exp_speech2text \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./slam_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
|
||||
--sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
|
||||
--use-lora True --unfreeze-llm True
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -357,26 +357,20 @@ class AsrDataModule:
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
# if self.args.concatenate_cuts:
|
||||
# transforms = [
|
||||
# CutConcatenate(
|
||||
# duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
# )
|
||||
# ] + transforms
|
||||
|
||||
"""
|
||||
Args:
|
||||
cuts_valid:
|
||||
CutSet for validation.
|
||||
"""
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
@ -434,4 +428,18 @@ class AsrDataModule:
|
||||
cut_set = cut_set.resample(16000)
|
||||
return {'test':cut_set}
|
||||
else:
|
||||
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
|
||||
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
if self.args.on_the_fly_feats:
|
||||
pass
|
||||
else:
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
|
@ -206,13 +206,6 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-aishell",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to only use aishell1 dataset for training.",
|
||||
)
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
add_model_arguments(parser)
|
||||
|
||||
@ -297,13 +290,11 @@ def compute_loss(
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_len: int,
|
||||
) -> Dict:
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
print(msg,23333333333333)
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
@ -311,11 +302,16 @@ def compute_loss(
|
||||
chat_template=TEMPLATE,
|
||||
add_generation_prompt=False,
|
||||
padding="longest", # FIX me change padding to longest
|
||||
max_length=max_len,
|
||||
truncation=True,
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
|
||||
# remove too long text
|
||||
texts = [ text for text in texts if len(text) < 1024 ]
|
||||
if len(texts) != len(messages):
|
||||
logging.warning(
|
||||
f"Remove too long text, {messages} "
|
||||
)
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
@ -336,18 +332,14 @@ def compute_loss(
|
||||
mask_prompt = True
|
||||
if mask_prompt:
|
||||
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
default_speech_token_indices = torch.where(
|
||||
mask_indices = torch.where(
|
||||
input_ids == default_speech_token_id
|
||||
)
|
||||
mask_indices = torch.where(
|
||||
input_ids == tokenizer.convert_tokens_to_ids("assistant")
|
||||
)
|
||||
print(mask_indices, default_speech_token_indices, default_speech_token_id)
|
||||
for i in range(mask_indices[0].size(0)):
|
||||
row = mask_indices[0][i]
|
||||
col = mask_indices[1][i]
|
||||
# + 2 to skip: 'assistant', '\n'
|
||||
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
|
||||
# + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
|
||||
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
@ -380,6 +372,7 @@ def compute_loss(
|
||||
message = []
|
||||
if total_round > 1:
|
||||
history_question_answer = history_contexts[i].split('USER:')
|
||||
history_question_answer = [item for item in history_question_answer if item]
|
||||
for j in range(total_round - 1):
|
||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
|
||||
question_answer = history_question_answer[j].split('ASSISTANT:')
|
||||
@ -393,7 +386,7 @@ def compute_loss(
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
|
||||
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
|
||||
|
||||
target_ids = target_ids.type(torch.LongTensor)
|
||||
input_ids = input_ids.type(torch.LongTensor)
|
||||
@ -508,7 +501,7 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
if batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
@ -720,7 +713,6 @@ def run(rank, world_size, args):
|
||||
)
|
||||
|
||||
data_module = AsrDataModule(args)
|
||||
multi_dataset = MultiDataset(args.manifest_dir)
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
@ -738,10 +730,8 @@ def run(rank, world_size, args):
|
||||
return False
|
||||
return True
|
||||
|
||||
if params.use_aishell:
|
||||
train_cuts = multi_dataset.aishell_train_cuts()
|
||||
else:
|
||||
train_cuts = multi_dataset.train_cuts()
|
||||
|
||||
train_cuts = data_module.train_cuts()
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
@ -754,10 +744,7 @@ def run(rank, world_size, args):
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
if params.use_aishell:
|
||||
valid_cuts = multi_dataset.aishell_dev_cuts()
|
||||
else:
|
||||
valid_cuts = multi_dataset.dev_cuts()
|
||||
valid_cuts = data_module.dev_cuts()
|
||||
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user