mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
make asr decode results align
This commit is contained in:
parent
cca562d538
commit
e6897b10fa
@ -40,6 +40,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--use-lora True
|
||||
--use-lora False # --on-the-fly-feats True
|
||||
|
||||
fi
|
||||
|
@ -39,9 +39,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from speech_dataset import K2SpeechRecognitionDataset
|
||||
|
||||
class _SeedWorkers:
|
||||
@ -396,7 +396,7 @@ class AsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
|
||||
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
@ -419,19 +419,19 @@ class AsrDataModule:
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
if self.args.on_the_fly_feats:
|
||||
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
|
||||
# dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
|
||||
i, num_digits = 0, 5
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
parquet_files = [
|
||||
f"data/train-{idx}-of-01601.parquet",
|
||||
]
|
||||
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
|
||||
parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
|
||||
file_name = parquet_files[0]
|
||||
logging.info(f"Loading dataset from {file_name}")
|
||||
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
|
||||
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
|
||||
if args.resample_to_16kHz:
|
||||
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
|
||||
if self.args.resample_to_16kHz:
|
||||
cut_set = cut_set.resample(16000)
|
||||
return cut_set
|
||||
return {'test':cut_set}
|
||||
else:
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
|
||||
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
|
@ -318,9 +318,9 @@ def decode_one_batch(
|
||||
2,
|
||||
)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
feature_len = feature_len.to(device, dtype=dtype)
|
||||
# supervisions = batch["supervisions"]
|
||||
# feature_len = supervisions["num_frames"]
|
||||
# feature_len = feature_len.to(device, dtype=dtype)
|
||||
|
||||
messages = [
|
||||
[
|
||||
@ -336,9 +336,6 @@ def decode_one_batch(
|
||||
)
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
print(hyps)
|
||||
print(supervisions)
|
||||
|
||||
return {"beam-search": hyps}
|
||||
|
||||
|
||||
@ -408,7 +405,10 @@ def decode_dataset(
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
answers = batch["supervisions"]["text"]
|
||||
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
|
||||
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
|
||||
texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
|
@ -66,7 +66,7 @@ from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
|
||||
from multi_dataset import MultiDataset
|
||||
# from multi_dataset import MultiDataset
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
Loading…
x
Reference in New Issue
Block a user