mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Modified aishell/ASR/conformer_ctc/decode.py and asr_datamodule.py for batch-wise decoding, which is faster.
This commit is contained in:
parent
19ce1a4fb4
commit
448c28b3cc
@ -366,13 +366,14 @@ def decode_dataset(
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
# try:
|
||||
# num_batches = len(dl)
|
||||
# except TypeError:
|
||||
# num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
batch = batch[0]
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
@ -399,9 +400,8 @@ def decode_dataset(
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
# batch_str = f"{batch_idx}/{num_batches}"
|
||||
logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -547,20 +547,19 @@ def main():
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
# for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
lexicon=lexicon,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
lexicon=lexicon,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from lhotse.cut import MonoCut
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
CutConcatenate,
|
||||
@ -180,7 +181,34 @@ class AishellAsrDataModule:
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
def to_dict(self, obj):
    """Recursively convert ``obj`` (and any nested objects) into plain dicts.

    Scalars (str/int/float/bool/None) are returned unchanged; lists are
    converted element-wise; dicts are converted value-wise; any other object
    exposing ``__dict__`` is turned into a ``{attribute: value}`` dict.

    Args:
        obj: The object to convert.
    Returns:
        A structure built only from dicts, lists, and scalar values.
    Raises:
        TypeError: If an object of an unsupported type (no ``__dict__``,
            e.g. a tuple or set) is encountered anywhere in the structure.
    """

    # Bug fix: the original recursed via the bare name `to_dict`, which is
    # not in scope inside the method (it is an attribute of the class), so
    # any nested input raised NameError. Recurse through a local helper
    # instead, which also avoids depending on `self`.
    def _convert(o):
        if isinstance(o, (str, int, float, bool, type(None))):
            return o
        if isinstance(o, list):
            return [_convert(item) for item in o]
        if isinstance(o, dict):
            return {key: _convert(value) for key, value in o.items()}
        if hasattr(o, '__dict__'):
            return {key: _convert(value) for key, value in o.__dict__.items()}
        raise TypeError(f"Unsupported type: {type(o)}")

    return _convert(obj)
|
||||
|
||||
def my_collate_fn(self, batch):
    """Collate a batch for the DataLoader, converting ``MonoCut`` items to dicts.

    Plain dicts pass through unchanged; ``MonoCut`` instances are converted
    via :meth:`to_dict`. NOTE(review): items of any other type are silently
    dropped from the batch — presumably intentional filtering; confirm.
    """
    collated = []
    for element in batch:
        if isinstance(element, dict):
            collated.append(element)
        elif isinstance(element, MonoCut):
            collated.append(self.to_dict(element))
    return collated
|
||||
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||
) -> DataLoader:
|
||||
@ -354,9 +382,10 @@ class AishellAsrDataModule:
|
||||
)
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
batch_size=100, # specified to some value
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
num_workers=4, # if larger, it will be more time-consuming for decoding, may stuck
|
||||
collate_fn=self.my_collate_fn
|
||||
)
|
||||
return test_dl
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user