mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
make asr decode results align
This commit is contained in:
parent
cca562d538
commit
e6897b10fa
@ -40,6 +40,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
--epoch 999 --avg 1 \
|
--epoch 999 --avg 1 \
|
||||||
--manifest-dir data/fbank \
|
--manifest-dir data/fbank \
|
||||||
--use-flash-attn True \
|
--use-flash-attn True \
|
||||||
--use-lora True
|
--use-lora False # --on-the-fly-feats True
|
||||||
|
|
||||||
fi
|
fi
|
||||||
|
@ -39,9 +39,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|||||||
)
|
)
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
from speech_dataset import K2SpeechRecognitionDataset
|
from speech_dataset import K2SpeechRecognitionDataset
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
@ -396,7 +396,7 @@ class AsrDataModule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
|
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cpu')))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else eval(self.args.input_strategy)(),
|
else eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
@ -419,19 +419,19 @@ class AsrDataModule:
|
|||||||
def test_cuts(self) -> CutSet:
|
def test_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test cuts")
|
logging.info("About to get test cuts")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
|
# dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
|
||||||
i, num_digits = 0, 5
|
i, num_digits = 0, 5
|
||||||
idx = f"{i}".zfill(num_digits)
|
idx = f"{i}".zfill(num_digits)
|
||||||
parquet_files = [
|
parquet_files = [
|
||||||
f"data/train-{idx}-of-01601.parquet",
|
f"data/train-{idx}-of-01601.parquet",
|
||||||
]
|
]
|
||||||
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
|
parquet_files = [f"{self.args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
|
||||||
file_name = parquet_files[0]
|
file_name = parquet_files[0]
|
||||||
logging.info(f"Loading dataset from {file_name}")
|
logging.info(f"Loading dataset from {file_name}")
|
||||||
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
|
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
|
||||||
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
|
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=self.args.audio_key, text_key=self.args.text_key)
|
||||||
if args.resample_to_16kHz:
|
if self.args.resample_to_16kHz:
|
||||||
cut_set = cut_set.resample(16000)
|
cut_set = cut_set.resample(16000)
|
||||||
return cut_set
|
return {'test':cut_set}
|
||||||
else:
|
else:
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")
|
return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
|
@ -318,9 +318,9 @@ def decode_one_batch(
|
|||||||
2,
|
2,
|
||||||
)
|
)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
# supervisions = batch["supervisions"]
|
||||||
feature_len = supervisions["num_frames"]
|
# feature_len = supervisions["num_frames"]
|
||||||
feature_len = feature_len.to(device, dtype=dtype)
|
# feature_len = feature_len.to(device, dtype=dtype)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
[
|
[
|
||||||
@ -336,9 +336,6 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
print(hyps)
|
|
||||||
print(supervisions)
|
|
||||||
|
|
||||||
return {"beam-search": hyps}
|
return {"beam-search": hyps}
|
||||||
|
|
||||||
|
|
||||||
@ -408,7 +405,10 @@ def decode_dataset(
|
|||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
answers = batch["supervisions"]["text"]
|
||||||
|
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
|
||||||
|
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
|
||||||
|
texts = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
|
@ -66,7 +66,7 @@ from lhotse.cut import Cut
|
|||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
|
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
|
||||||
from multi_dataset import MultiDataset
|
# from multi_dataset import MultiDataset
|
||||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
Loading…
x
Reference in New Issue
Block a user