diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index e2eb77c6b..30f8ba76d 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -206,10 +206,11 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--use-aishell",
-        type=str2bool,
-        default=True,
-        help="Whether to only use aishell1 dataset for training.",
+        "--dataset",
+        type=str,
+        default="aishell",
+        choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
+        help="The dataset to decode",
     )
 
     add_model_arguments(parser)
@@ -540,7 +541,7 @@ def main():
 
     if params.avg > 1:
-        start = params.epoch - params.avg
+        start = params.epoch - params.avg + 1
         assert start >= 1, start
         checkpoint = torch.load(
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
         )
@@ -551,18 +552,17 @@ def main():
             f"{params.exp_dir}/epoch-{epoch}.pt"
             for epoch in range(start, params.epoch + 1)
         ]
-        model.load_state_dict(average_checkpoints(filenames), strict=False)
+        avg_checkpoint = average_checkpoints(filenames)
+        model.load_state_dict(avg_checkpoint, strict=False)
         filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(model.state_dict(), filename)
+        torch.save(avg_checkpoint, filename)
     else:
         checkpoint = torch.load(
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
         )
-        if "model" not in checkpoint:
-            model.load_state_dict(checkpoint, strict=False)
-        else:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        model.load_state_dict(checkpoint, strict=False)
+
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
 
@@ -584,11 +584,14 @@ def main():
             return False
         return True
 
-    if params.use_aishell:
+    if params.dataset == "aishell":
         test_sets_cuts = multi_dataset.aishell_test_cuts()
-    else:
-        # test_sets_cuts = multi_dataset.test_cuts()
+    elif params.dataset == "speechio":
+        test_sets_cuts = multi_dataset.speechio_test_cuts()
+    elif params.dataset == "wenetspeech_test_meeting":
         test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
+    else:
+        test_sets_cuts = multi_dataset.test_cuts()
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index 9b3ef6e69..abfa41b3f 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -330,4 +330,26 @@ class MultiDataset:
 
         return {
             "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
-        }
\ No newline at end of file
+        }
+
+    def speechio_test_cuts(self) -> Dict[str, CutSet]:
+        logging.info("About to get speechio test cuts")
+        start_index = 0
+        end_index = 26
+        dataset_parts = []
+        for i in range(start_index, end_index + 1):
+            idx = f"{i}".zfill(2)
+            dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
+
+        prefix = "speechio"
+        suffix = "jsonl.gz"
+
+        results_dict = {}
+        for partition in dataset_parts:
+            path = f"{prefix}_cuts_{partition}.{suffix}"
+
+            logging.info(f"Loading {path} set in lazy mode")
+            test_cuts = load_manifest_lazy(self.fbank_dir / path)
+            results_dict[partition] = test_cuts
+
+        return results_dict
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 99b8dae0b..0815b6d3a 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -751,7 +751,6 @@ def run(rank, world_size, args):
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
-        assert len(unexpected_keys) == 0, unexpected_keys
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
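
For reviewers, a small standalone sketch of the two less obvious details in this patch. It is illustration only, not part of the diff, and the concrete epoch/avg values are hypothetical:

# Illustration only -- not applied by the patch; epoch/avg values are hypothetical.

# 1) The checkpoint-averaging fix in decode.py: averaging the last `avg`
#    checkpoints ending at `epoch` must start at epoch - avg + 1, because
#    range(start, epoch + 1) covers every epoch from start through epoch.
epoch, avg = 10, 3
assert len(range(epoch - avg, epoch + 1)) == avg + 1  # old start: one checkpoint too many
assert len(range(epoch - avg + 1, epoch + 1)) == avg  # fixed start: exactly `avg` checkpoints

# 2) The partition names built by speechio_test_cuts(): zero-padding i to two
#    digits after the literal "ZH000" yields the five-digit SpeechIO suffixes.
names = [f"SPEECHIO_ASR_ZH000{str(i).zfill(2)}" for i in range(0, 26 + 1)]
assert names[0] == "SPEECHIO_ASR_ZH00000"
assert names[-1] == "SPEECHIO_ASR_ZH00026"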