Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 02:52:18 +00:00)
Commit 3ad075af60 (parent 1d11662016): s2t training
@@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
 set -eou pipefail
 
-stage=1
-stop_stage=1
+stage=$1
+stop_stage=$2
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
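With this change the stage range is taken from the command line instead of being hard-coded. Every stage guard has the form [ $stage -le N ] && [ $stop_stage -ge N ], so passing "0 1" to the script (its name is not shown in this view) runs stages 0 and 1, while "3 3" runs only the training stage.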
@@ -20,10 +20,12 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
+  pip uninstall lhotse
   cd /workspace/slam/lhotse
   git config --global --add safe.directory /workspace/slam/lhotse
   pip install -e '.[dev]'
   cd -
+  pip install -r slam_omni/requirements.txt
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -32,6 +34,20 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   python3 local/compute_whisper_fbank.py
 fi
 
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 3: Combine features"
+  manifest_dir=data/fbank
+  if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
+    pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
+    # # remove cust_belle_00000.jsonl.gz from pieces
+    # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
+    echo $pieces | wc
+    lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
+    cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
+  fi
+fi
+
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
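Aside: the lhotse combine call above merges the per-shard cut manifests into one training manifest and symlinks it to the name the data module's train_cuts() expects (see the hunks below). A minimal Python sketch of the same step, assuming lhotse's combine and load_manifest_lazy helpers:

    from pathlib import Path

    from lhotse import combine, load_manifest_lazy

    manifest_dir = Path("data/fbank")
    # the same shard manifests the find | sort pipeline collects
    pieces = sorted(manifest_dir.glob("cuts_belle.*.jsonl.gz"))
    cuts = combine(*(load_manifest_lazy(p) for p in pieces))
    cuts.to_file(manifest_dir / "cuts_belle_00001-01600.jsonl.gz")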
@@ -46,17 +62,21 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 
 fi
 
-ngpu=2
+ngpu=8
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-    --max-duration 200 \
-    --exp-dir ./slam_omni/exp_test \
+    --max-duration 80 \
+    --enable-musan False \
+    --exp-dir ./slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./slam_omni/ds_config_zero1.json \
     --use-flash-attn True \
+    --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
+    --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
     --use-lora True --unfreeze-llm True
 fi
+
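Aside: the two new flags resume the interrupted exp_speech2text run from checkpoint 5000. A hedged sketch of how train.py-style code typically consumes them (the loading code itself is not part of this diff; names follow the recipe):

    import torch

    # --pretrained-model-path: restore model weights from the earlier run
    state = torch.load(args.pretrained_model_path, map_location="cpu")
    model.load_state_dict(state, strict=False)

    # --sampler-state-dict-path: restore the sampler state so the data
    # stream continues where the earlier run stopped; run() below passes
    # it to train_dataloaders()
    sampler_state_dict = torch.load(args.sampler_state_dict_path)
    train_dl = data_module.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )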
(The hunks below are from a second file, whose name is not shown in this flattened view; it defines class AsrDataModule.)

@@ -357,26 +357,20 @@ class AsrDataModule:
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
-        transforms = []
-        # if self.args.concatenate_cuts:
-        #     transforms = [
-        #         CutConcatenate(
-        #             duration_factor=self.args.duration_factor, gap=self.args.gap
-        #         )
-        #     ] + transforms
+        """
+        Args:
+          cuts_valid:
+            CutSet for validation.
+        """
 
         logging.info("About to create dev dataset")
-        if self.args.on_the_fly_feats:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
-                return_cuts=self.args.return_cuts,
-            )
-        else:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                return_cuts=self.args.return_cuts,
-            )
+        validate = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
+            return_cuts=self.args.return_cuts,
+        )
 
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
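The rewritten valid_dataloaders collapses the old if/else into one constructor call: on-the-fly 80-bin Whisper fbank on GPU when --on-the-fly-feats is set, otherwise the class named by --input-strategy (in icefall data modules this is typically "PrecomputedFeatures", instantiated via eval). A spelled-out equivalent:

    if self.args.on_the_fly_feats:
        strategy = OnTheFlyFeatures(
            WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
        )
    else:
        strategy = eval(self.args.input_strategy)()  # e.g. PrecomputedFeatures()
    validate = K2SpeechRecognitionDataset(
        input_strategy=strategy,
        return_cuts=self.args.return_cuts,
    )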
@@ -435,3 +429,17 @@ class AsrDataModule:
             return {'test':cut_set}
         else:
             return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
+
+    @lru_cache()
+    def dev_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+        if self.args.on_the_fly_feats:
+            pass
+        else:
+            return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
+
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz")
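These accessors are what the rewritten run() at the end of this diff consumes. A usage sketch (note that dev_cuts() implicitly returns None when --on-the-fly-feats is set, since that branch is still a bare pass):

    data_module = AsrDataModule(args)
    train_cuts = data_module.train_cuts()  # data/fbank/cuts_belle_train.jsonl.gz
    valid_cuts = data_module.dev_cuts()    # data/fbank/cuts_belle.00000.jsonl.gz
    valid_dl = data_module.valid_dataloaders(valid_cuts)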
(The remaining hunks are from a third file, name not shown in this view; it contains get_parser, compute_loss, train_one_epoch, and run from the slam_omni training code.)

@@ -206,13 +206,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--use-aishell",
-        type=str2bool,
-        default=True,
-        help="Whether to only use aishell1 dataset for training.",
-    )
-
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
 
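Removing --use-aishell pairs with the run() changes at the end of this diff: training and validation cuts now always come from AsrDataModule.train_cuts() and dev_cuts() rather than from a MultiDataset selected by this flag.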
@@ -297,13 +290,11 @@ def compute_loss(
 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
-    max_len: int,
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
     TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
-        print(msg,23333333333333)
         texts.append(
             tokenizer.apply_chat_template(
                 msg,
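For reference, TEMPLATE above is a Qwen-style chat template; a worked example of what it renders (DEFAULT_SPEECH_TOKEN stands for the recipe's speech placeholder token; apply_chat_template tokenizes the rendered string by default):

    msg = [
        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN} Transcribe the speech."},
        {"role": "assistant", "content": "The weather is nice today."},
    ]
    # TEMPLATE renders:
    #   <|im_start|>user
    #   {DEFAULT_SPEECH_TOKEN} Transcribe the speech.<|im_end|>
    #   <|im_start|>assistant
    #   The weather is nice today.<|im_end|>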
@@ -311,11 +302,16 @@ def compute_loss(
                 chat_template=TEMPLATE,
                 add_generation_prompt=False,
                 padding="longest",  # FIX me change padding to longest
-                max_length=max_len,
-                truncation=True,
+                truncation=False,
             )
         )
     # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
+    # remove too long text
+    texts = [ text for text in texts if len(text) < 1024 ]
+    if len(texts) != len(messages):
+        logging.warning(
+            f"Remove too long text, {messages} "
+        )
     max_len_texts = max([len(text) for text in texts])
     if tokenizer.padding_side == "right":
         texts = [
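With truncation=False an over-long conversation is no longer clipped to max_len; instead the whole rendered token list is dropped once it reaches 1024 tokens (apply_chat_template returns token ids by default, so len(text) counts tokens), and a warning logging the batch's messages is emitted.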
@@ -336,18 +332,14 @@ def compute_loss(
     mask_prompt = True
     if mask_prompt:
         default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-        default_speech_token_indices = torch.where(
+        mask_indices = torch.where(
             input_ids == default_speech_token_id
         )
-        mask_indices = torch.where(
-            input_ids == tokenizer.convert_tokens_to_ids("assistant")
-        )
-        print(mask_indices, default_speech_token_indices, default_speech_token_id)
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
-            # + 2 to skip: 'assistant', '\n'
-            target_ids[row, : col + 2] = IGNORE_TOKEN_ID
+            # + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
 
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
 
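The updated comment lists the six token ids between the speech token and the answer; a hedged decoding against the Qwen2 tokenizer (the first id is the added speech token, whose exact value depends on the vocabulary):

    # col is the position of DEFAULT_SPEECH_TOKEN, so target_ids[row, : col + 6]
    # masks the prompt up to and including:
    #   151665  DEFAULT_SPEECH_TOKEN (added token)
    #   151645  <|im_end|>
    #      198  "\n"
    #   151644  <|im_start|>
    #    77091  "assistant"
    #      198  "\n"
    # leaving only the assistant's answer to contribute to the loss.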
@@ -380,6 +372,7 @@ def compute_loss(
         message = []
         if total_round > 1:
             history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
                 # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
                 question_answer = history_question_answer[j].split('ASSISTANT:')
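The added filter drops the empty leading element that str.split produces when the history starts with "USER:". (The Chinese sample comment above reads, roughly: "USER: Write a poem about summer. ASSISTANT: Blazing summer, everything grows, bright sunshine, enjoying the good times of summer. USER: List some news headlines for me. ASSISTANT: The news of today's society never stops.") A worked example:

    history = "USER: q1 ASSISTANT: a1 USER: q2 ASSISTANT: a2"
    parts = history.split('USER:')
    # parts == ['', ' q1 ASSISTANT: a1 ', ' q2 ASSISTANT: a2']
    parts = [item for item in parts if item]  # drop the leading empty string
    # parts[0].split('ASSISTANT:') == [' q1 ', ' a1 ']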
@@ -393,7 +386,7 @@ def compute_loss(
         ]
         messages.append(message)
 
-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
 
     target_ids = target_ids.type(torch.LongTensor)
     input_ids = input_ids.type(torch.LongTensor)
@@ -508,7 +501,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
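Dropping the "not params.print_diagnostics" guard means validation loss is now computed every valid_interval batches even when diagnostics are being printed.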
@@ -720,7 +713,6 @@ def run(rank, world_size, args):
     )
 
     data_module = AsrDataModule(args)
-    multi_dataset = MultiDataset(args.manifest_dir)
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -738,10 +730,8 @@ def run(rank, world_size, args):
             return False
         return True
 
-    if params.use_aishell:
-        train_cuts = multi_dataset.aishell_train_cuts()
-    else:
-        train_cuts = multi_dataset.train_cuts()
+    train_cuts = data_module.train_cuts()
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
@@ -754,10 +744,7 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    if params.use_aishell:
-        valid_cuts = multi_dataset.aishell_dev_cuts()
-    else:
-        valid_cuts = multi_dataset.dev_cuts()
+    valid_cuts = data_module.dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)
 
     if args.tensorboard and rank == 0: