Make ASR decode results align

This commit is contained in:
root 2025-02-26 07:08:34 +00:00
parent cca562d538
commit e6897b10fa
4 changed files with 17 additions and 17 deletions

View File

@ -40,6 +40,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
--epoch 999 --avg 1 \ --epoch 999 --avg 1 \
--manifest-dir data/fbank \ --manifest-dir data/fbank \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --use-lora False # --on-the-fly-feats True
fi fi

View File

@ -39,9 +39,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
) )
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from datasets import load_dataset
from icefall.utils import str2bool from icefall.utils import str2bool
from speech_dataset import K2SpeechRecognitionDataset from speech_dataset import K2SpeechRecognitionDataset
class _SeedWorkers: class _SeedWorkers:
@ -396,7 +396,7 @@ class AsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))) input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(), else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
@ -419,19 +419,19 @@ class AsrDataModule:
def test_cuts(self) -> CutSet: def test_cuts(self) -> CutSet:
logging.info("About to get test cuts") logging.info("About to get test cuts")
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition) # dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
i, num_digits = 0, 5 i, num_digits = 0, 5
idx = f"{i}".zfill(num_digits) idx = f"{i}".zfill(num_digits)
parquet_files = [ parquet_files = [
f"data/train-{idx}-of-01601.parquet", f"data/train-{idx}-of-01601.parquet",
] ]
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files] parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
file_name = parquet_files[0] file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}") logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train') dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key) cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
if args.resample_to_16kHz: if self.args.resample_to_16kHz:
cut_set = cut_set.resample(16000) cut_set = cut_set.resample(16000)
return cut_set return {'test':cut_set}
else: else:
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz") return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}

View File

@ -318,9 +318,9 @@ def decode_one_batch(
2, 2,
) )
supervisions = batch["supervisions"] # supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"] # feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype) # feature_len = feature_len.to(device, dtype=dtype)
messages = [ messages = [
[ [
@ -336,9 +336,6 @@ def decode_one_batch(
) )
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(hyps)
print(supervisions)
return {"beam-search": hyps} return {"beam-search": hyps}
@ -408,7 +405,10 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] answers = batch["supervisions"]["text"]
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(

View File

@ -66,7 +66,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset # from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter