Modified aishell/ASR/conformer_ctc/decode.py and asr_datamodule.py for batch-wise decoding, which is faster.

czl66 2024-12-24 13:57:47 +08:00
parent 19ce1a4fb4
commit 448c28b3cc
2 changed files with 50 additions and 22 deletions

aishell/ASR/conformer_ctc/decode.py

@@ -366,13 +366,14 @@ def decode_dataset(
     num_cuts = 0

-    try:
-        num_batches = len(dl)
-    except TypeError:
-        num_batches = "?"
+    # try:
+    #     num_batches = len(dl)
+    # except TypeError:
+    #     num_batches = "?"

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        batch = batch[0]
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@@ -399,9 +400,8 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])

         if batch_idx % 100 == 0:
-            batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            # batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")

     return results
@@ -547,20 +547,19 @@ def main():
     test_sets = ["test"]
     test_dls = [test_dl]

-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            lexicon=lexicon,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    # for test_set, test_dl in zip(test_sets, test_dls):
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        H=H,
+        lexicon=lexicon,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)

     logging.info("Done!")
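Note on the added `batch = batch[0]` line: with the custom `my_collate_fn` introduced in asr_datamodule.py (second file below), the test DataLoader yields a list of dicts instead of a single batch dict, so `decode_dataset` unwraps the first element before indexing into it. A minimal illustration with dummy data (not the recipe's real batch contents):

    # One item as yielded by the DataLoader after the collate change:
    # a list wrapping the usual batch dict.
    yielded = [{"supervisions": {"text": ["甚至出现交易几乎停滞的情况"]}}]

    batch = yielded[0]  # the unwrap that decode_dataset() now performs
    texts = batch["supervisions"]["text"]
    assert texts == ["甚至出现交易几乎停滞的情况"]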

aishell/ASR/conformer_ctc/asr_datamodule.py

@@ -23,6 +23,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+from lhotse.cut import MonoCut
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -180,7 +181,34 @@ class AishellAsrDataModule:
             help="When enabled, select noise from MUSAN and mix it"
             "with training dataset. ",
         )

+    def to_dict(self, obj):
+        """
+        Recursively convert an object and its nested objects to dictionaries.
+        """
+        if isinstance(obj, (str, int, float, bool, type(None))):
+            return obj
+        elif isinstance(obj, list):
+            return [self.to_dict(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {key: self.to_dict(value) for key, value in obj.items()}
+        elif hasattr(obj, "__dict__"):
+            return {key: self.to_dict(value) for key, value in obj.__dict__.items()}
+        else:
+            raise TypeError(f"Unsupported type: {type(obj)}")
+
+    def my_collate_fn(self, batch):
+        """
+        Convert each MonoCut in the batch to a plain dict.
+        """
+        return_batch = []
+        for item in batch:
+            if isinstance(item, MonoCut):
+                processed_item = self.to_dict(item)
+                return_batch.append(processed_item)
+            elif isinstance(item, dict):
+                return_batch.append(item)
+        return return_batch
+
     def train_dataloaders(
         self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
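The `to_dict` helper above recursively serializes any object whose leaves are primitives, lists, dicts, or objects exposing `__dict__`; anything else (a tuple, for instance, which carries no `__dict__`) hits the `TypeError` branch, which is why `my_collate_fn` only routes `MonoCut` instances and plain dicts through it. A standalone toy demonstration of the same recursion (the `Cut`/`Supervision` classes here are made-up stand-ins, not lhotse's):

    class Supervision:
        def __init__(self, text):
            self.text = text


    class Cut:
        def __init__(self, cut_id, supervisions):
            self.id = cut_id
            self.supervisions = supervisions


    def to_dict(obj):
        # Module-level copy of AishellAsrDataModule.to_dict, for illustration.
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        elif isinstance(obj, list):
            return [to_dict(item) for item in obj]
        elif isinstance(obj, dict):
            return {key: to_dict(value) for key, value in obj.items()}
        elif hasattr(obj, "__dict__"):
            return {key: to_dict(value) for key, value in obj.__dict__.items()}
        else:
            raise TypeError(f"Unsupported type: {type(obj)}")


    cut = Cut("BAC009S0764W0121", [Supervision("大家都在欢度国庆")])
    print(to_dict(cut))
    # -> {'id': 'BAC009S0764W0121', 'supervisions': [{'text': '大家都在欢度国庆'}]}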
@@ -354,9 +382,10 @@ class AishellAsrDataModule:
         )
         test_dl = DataLoader(
             test,
-            batch_size=None,
+            batch_size=100,  # a fixed batch size instead of None
             sampler=sampler,
-            num_workers=self.args.num_workers,
+            num_workers=4,  # larger values slow down decoding and may hang
+            collate_fn=self.my_collate_fn,
         )
         return test_dl
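Taken together, the changed DataLoader arguments mean each iteration now delivers up to 100 items, pre-converted to plain dicts by `my_collate_fn`. A self-contained sketch of that behavior, with a dummy dataset standing in for the lhotse test set and the sampler omitted (all names below are illustrative, not the recipe's):

    from torch.utils.data import DataLoader, Dataset


    class DummyTestSet(Dataset):
        # Stand-in for the recipe's test dataset: one dict per index.
        def __len__(self):
            return 400

        def __getitem__(self, idx):
            return {"supervisions": {"text": [f"utt-{idx}"]}}


    def my_collate_fn(batch):
        # Module-level analogue of AishellAsrDataModule.my_collate_fn:
        # dict items pass through unchanged, returned as a plain list.
        return [item for item in batch if isinstance(item, dict)]


    test_dl = DataLoader(
        DummyTestSet(),
        batch_size=100,  # 100 items per step instead of one at a time
        num_workers=0,   # the commit keeps this small to avoid slow or stuck decoding
        collate_fn=my_collate_fn,
    )

    for batch_idx, batch in enumerate(test_dl):
        print(f"batch {batch_idx}: {len(batch)} dicts")
    # batch 0: 100 dicts ... batch 3: 100 dicts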