diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 6c7393379..e23f26684 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
 set -eou pipefail
 
-stage=1
-stop_stage=1
+stage=$1
+stop_stage=$2
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
@@ -20,10 +20,12 @@ log() {
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: "
+  pip uninstall -y lhotse
   cd /workspace/slam/lhotse
   git config --global --add safe.directory /workspace/slam/lhotse
   pip install -e '.[dev]'
   cd -
+  pip install -r slam_omni/requirements.txt
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -32,6 +34,20 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   python3 local/compute_whisper_fbank.py
 fi
 
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "stage 10: Combine features"
+  manifest_dir=data/fbank
+  if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
+    pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
+    # Optionally exclude cuts_belle.00000.jsonl.gz here; it is held out as the dev set:
+    # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
+    log "Combining $(echo $pieces | wc -w) manifest pieces"
+    lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
+    cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
+  fi
+fi
+
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
@@ -46,17 +62,21 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 
-ngpu=2
+ngpu=8
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-    --max-duration 200 \
-    --exp-dir ./slam_omni/exp_test \
+    --max-duration 80 \
+    --enable-musan False \
+    --exp-dir ./slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./slam_omni/ds_config_zero1.json \
     --use-flash-attn True \
+    --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
+    --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
     --use-lora True --unfreeze-llm True
-fi
\ No newline at end of file
+fi
+
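Note: the new stage 10 is equivalent to the following lhotse Python API calls. This is a minimal sketch, assuming the shard naming (cuts_belle.*.jsonl.gz) and output path used by the script; combine and CutSet.from_file are standard lhotse APIs.

    # Sketch: merge the per-shard fbank manifests into one training manifest.
    from pathlib import Path

    from lhotse import CutSet, combine

    manifest_dir = Path("data/fbank")
    pieces = sorted(manifest_dir.glob("cuts_belle.*.jsonl.gz"))
    combined = combine([CutSet.from_file(p) for p in pieces])
    combined.to_file(manifest_dir / "cuts_belle_00001-01600.jsonl.gz")
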
+ """ logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) + + validate = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, @@ -434,4 +428,18 @@ class AsrDataModule: cut_set = cut_set.resample(16000) return {'test':cut_set} else: - return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")} \ No newline at end of file + return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")} + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get test cuts") + if self.args.on_the_fly_feats: + pass + else: + return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz") + + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_belle_train.jsonl.gz") \ No newline at end of file diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py index 7fc207455..f0df303e4 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py @@ -206,13 +206,6 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--use-aishell", - type=str2bool, - default=True, - help="Whether to only use aishell1 dataset for training.", - ) - parser = deepspeed.add_config_arguments(parser) add_model_arguments(parser) @@ -297,13 +290,11 @@ def compute_loss( def preprocess( messages, tokenizer: transformers.PreTrainedTokenizer, - max_len: int, ) -> Dict: """Preprocesses the data for supervised fine-tuning.""" texts = [] TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" for i, msg in enumerate(messages): - print(msg,23333333333333) texts.append( tokenizer.apply_chat_template( msg, @@ -311,11 +302,16 @@ def compute_loss( chat_template=TEMPLATE, add_generation_prompt=False, padding="longest", # FIX me change padding to longest - max_length=max_len, - truncation=True, + truncation=False, ) ) # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id + # remove too long text + texts = [ text for text in texts if len(text) < 1024 ] + if len(texts) != len(messages): + logging.warning( + f"Remove too long text, {messages} " + ) max_len_texts = max([len(text) for text in texts]) if tokenizer.padding_side == "right": texts = [ @@ -336,18 +332,14 @@ def compute_loss( mask_prompt = True if mask_prompt: default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) - default_speech_token_indices = torch.where( + mask_indices = torch.where( input_ids == default_speech_token_id ) - mask_indices = torch.where( - input_ids == tokenizer.convert_tokens_to_ids("assistant") - ) - print(mask_indices, 
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 7fc207455..f0df303e4 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -206,13 +206,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--use-aishell",
-        type=str2bool,
-        default=True,
-        help="Whether to only use aishell1 dataset for training.",
-    )
-
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
 
@@ -297,13 +290,11 @@ def compute_loss(
     def preprocess(
         messages,
         tokenizer: transformers.PreTrainedTokenizer,
-        max_len: int,
     ) -> Dict:
         """Preprocesses the data for supervised fine-tuning."""
         texts = []
         TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
         for i, msg in enumerate(messages):
-            print(msg,23333333333333)
             texts.append(
                 tokenizer.apply_chat_template(
                     msg,
@@ -311,11 +302,16 @@ def compute_loss(
                     chat_template=TEMPLATE,
                     add_generation_prompt=False,
                     padding="longest",  # FIX me change padding to longest
-                    max_length=max_len,
-                    truncation=True,
+                    truncation=False,
                 )
             )
         # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
+        # With truncation disabled, drop over-long samples instead.
+        texts = [text for text in texts if len(text) < 1024]
+        if len(texts) != len(messages):
+            logging.warning(
+                f"Removed {len(messages) - len(texts)} over-long sample(s): {messages}"
+            )
         max_len_texts = max([len(text) for text in texts])
         if tokenizer.padding_side == "right":
             texts = [
@@ -336,18 +332,14 @@ def compute_loss(
         mask_prompt = True
         if mask_prompt:
             default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
-            default_speech_token_indices = torch.where(
+            mask_indices = torch.where(
                 input_ids == default_speech_token_id
             )
-            mask_indices = torch.where(
-                input_ids == tokenizer.convert_tokens_to_ids("assistant")
-            )
-            print(mask_indices, default_speech_token_indices, default_speech_token_id)
             for i in range(mask_indices[0].size(0)):
                 row = mask_indices[0][i]
                 col = mask_indices[1][i]
-                # + 2 to skip: 'assistant', '\n'
-                target_ids[row, : col + 2] = IGNORE_TOKEN_ID
+                # + 6 to skip: <speech> 151665, <|im_end|> 151645, '\n' 198, <|im_start|> 151644, 'assistant' 77091, '\n' 198
+                target_ids[row, : col + 6] = IGNORE_TOKEN_ID
 
         attention_mask = input_ids.ne(tokenizer.pad_token_id)
 
@@ -380,6 +372,7 @@ def compute_loss(
         message = []
         if total_round > 1:
             history_question_answer = history_contexts[i].split('USER:')
+            history_question_answer = [item for item in history_question_answer if item]  # drop the empty string before the first 'USER:'
             for j in range(total_round - 1):
                 # e.g. "USER: Write a poem about summer. ASSISTANT: The summer sun blazes, everything grows; enjoy the best of the season. USER: List some news headlines. ASSISTANT: Today's news never stops."
                 question_answer = history_question_answer[j].split('ASSISTANT:')
@@ -393,7 +386,7 @@ def compute_loss(
         ]
         messages.append(message)
 
-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
 
     target_ids = target_ids.type(torch.LongTensor)
     input_ids = input_ids.type(torch.LongTensor)
@@ -508,7 +501,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -720,7 +713,6 @@ def run(rank, world_size, args):
     )
 
     data_module = AsrDataModule(args)
-    multi_dataset = MultiDataset(args.manifest_dir)
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -738,10 +730,8 @@ def run(rank, world_size, args):
             return False
         return True
 
-    if params.use_aishell:
-        train_cuts = multi_dataset.aishell_train_cuts()
-    else:
-        train_cuts = multi_dataset.train_cuts()
+
+    train_cuts = data_module.train_cuts()
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
@@ -754,10 +744,7 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    if params.use_aishell:
-        valid_cuts = multi_dataset.aishell_dev_cuts()
-    else:
-        valid_cuts = multi_dataset.dev_cuts()
+    valid_cuts = data_module.dev_cuts()
 
     valid_dl = data_module.valid_dataloaders(valid_cuts)
 
     if args.tensorboard and rank == 0:
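The new "col + 6" masking is easiest to verify on a toy batch. Below is a minimal, self-contained sketch: the six ids after (and including) the speech token are the Qwen2.5 ids quoted in the comment above (assumed, not re-derived from the tokenizer); the ids before the speech token and the two trailing "answer" ids are arbitrary filler, since only positions matter here.

    import torch

    IGNORE_TOKEN_ID = -100    # LabelSmoother.ignore_index, as in the recipe
    SPEECH_TOKEN_ID = 151665  # id of DEFAULT_SPEECH_TOKEN per the comment above

    # Toy row: <|im_start|> user \n <speech> <|im_end|> \n <|im_start|> assistant \n a1 a2
    input_ids = torch.tensor(
        [[151644, 872, 198, 151665, 151645, 198, 151644, 77091, 198, 9707, 0]]
    )
    target_ids = input_ids.clone()
    rows, cols = torch.where(input_ids == SPEECH_TOKEN_ID)
    for row, col in zip(rows, cols):
        # col is the <speech> column; + 6 also covers <|im_end|>, '\n',
        # <|im_start|>, 'assistant' and '\n', so only answer tokens keep labels.
        target_ids[row, : col + 6] = IGNORE_TOKEN_ID
    assert (target_ids[0, :9] == IGNORE_TOKEN_ID).all()
    assert target_ids[0, 9:].tolist() == [9707, 0]

Everything set to IGNORE_TOKEN_ID (-100) is skipped by the cross-entropy loss, so gradients flow only through the assistant's answer tokens.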