From e6897b10fa5b79c9312515702cf2766f7cfc54eb Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Feb 2025 07:08:34 +0000 Subject: [PATCH] make asr decode results align --- egs/speech_llm/SPEECH2SPEECH/prepare.sh | 2 +- .../SPEECH2SPEECH/slam_omni/data_module.py | 16 ++++++++-------- egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py | 14 +++++++------- egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh index 87e7cd254..b61241974 100644 --- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh +++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh @@ -40,6 +40,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --epoch 999 --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --use-lora True + --use-lora False # --on-the-fly-feats True fi diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py index 35d1e3494..a8b1a4746 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py @@ -39,9 +39,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples ) from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader +from datasets import load_dataset from icefall.utils import str2bool - from speech_dataset import K2SpeechRecognitionDataset class _SeedWorkers: @@ -396,7 +396,7 @@ class AsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))) + input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu'))) if self.args.on_the_fly_feats else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, @@ -419,19 +419,19 @@ class AsrDataModule: def test_cuts(self) -> CutSet: logging.info("About to get test cuts") if self.args.on_the_fly_feats: - # dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition) + # dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition) i, num_digits = 0, 5 idx = f"{i}".zfill(num_digits) parquet_files = [ f"data/train-{idx}-of-01601.parquet", ] - parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files] + parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files] file_name = parquet_files[0] logging.info(f"Loading dataset from {file_name}") dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train') - cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key) - if args.resample_to_16kHz: + cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key) + if self.args.resample_to_16kHz: cut_set = cut_set.resample(16000) - return cut_set + return {'test':cut_set} else: - return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz") \ No newline at end of file + return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")} \ No newline at end of file diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py index 5f5334142..f878d32e7 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py @@ -318,9 +318,9 @@ def decode_one_batch( 2, ) - supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) + # supervisions = batch["supervisions"] + # feature_len = supervisions["num_frames"] + # feature_len = feature_len.to(device, dtype=dtype) messages = [ [ @@ -336,9 +336,6 @@ def decode_one_batch( ) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - print(hyps) - print(supervisions) - return {"beam-search": hyps} @@ -408,7 +405,10 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] + answers = batch["supervisions"]["text"] + questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]] + answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]] + texts = [question.split(': ')[-1].strip() for question in questions_with_history] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py index d9489b1ae..1c3ccd2c6 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py @@ -66,7 +66,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector -from multi_dataset import MultiDataset +# from multi_dataset import MultiDataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from torch import Tensor from torch.utils.tensorboard import SummaryWriter