diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
index 730937a21..9d2cef08f 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
@@ -19,14 +19,14 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 1e-4
+      "lr": 5e-4
     }
   },
   "scheduler": {
     "type": "WarmupLR",
     "params": {
       "warmup_min_lr": 0,
-      "warmup_max_lr": 1e-4,
+      "warmup_max_lr": 5e-4,
       "warmup_num_steps": 100
     }
   },
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index 3fc7c654b..b306b7fe5 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -6,7 +6,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
 class EncoderProjector(nn.Module):
     # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
-    def __init__(self, encoder_dim, llm_dim, downsample_rate=4):
+    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
         super().__init__()
         self.downsample_rate = downsample_rate
         self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
@@ -140,13 +140,16 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-        #print("input_ids", input_ids, input_ids.shape)
-        #print("labels", labels, labels.shape)
+        # print("input_ids", input_ids, input_ids.shape)
+        # print("labels", labels, labels.shape)
+        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        #print("labels", labels, labels.shape)
-        #print("speech_features", speech_features.shape)
+        # print("labels", labels, labels.shape)
+        # print("speech_features", speech_features.shape, speech_features)
+        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+        # input()
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
 
         with torch.no_grad():
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index e4b148ea5..52bbf8c64 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -280,4 +280,41 @@ class MultiDataset:
 
         return {
             "aishell_test": aishell_test_cuts,
-        }
\ No newline at end of file
+        }
+
+
+    # aishell 2
+    def aishell2_train_cuts(self) -> CutSet:
+        logging.info("About to get multidataset train cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 in lazy mode")
+        aishell_2_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+        )
+
+        return aishell_2_cuts
+
+    def aishell2_dev_cuts(self) -> CutSet:
+        logging.info("About to get multidataset dev cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 set in lazy mode")
+        aishell2_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
+
+        return aishell2_dev_cuts
+
+    def aishell2_test_cuts(self) -> CutSet:
+        logging.info("About to get multidataset test cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 set in lazy mode")
+        aishell2_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+        )
+
+        return {
+            "aishell2_test": aishell2_test_cuts,
+        }
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 9b650e747..43bab3491 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -820,7 +820,8 @@ def run(rank, world_size, args):
         return True
 
     # train_cuts = multi_dataset.train_cuts()
-    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell_train_cuts()
+    train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: