Sort results to make it more convenient to compare decoding results (#522)
* Sort results to make it more convenient to compare decoding results
* Add cut_id to recognition results
* Add cut_id to results for all recipes
* Fix torch.jit.script
* Fix comments
* Minor fixes
* Fix torch.jit.tracing for PyTorch versions before v1.9.0
This commit is contained in:
parent 5149788cb2
commit 5c17255eec
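Across all recipes the change follows one pattern, visible in the hunks below: `decode_dataset` collects the cut id of every utterance in the batch (which requires `return_cuts=True` on the data module), stores each result as a `(cut_id, ref, hyp)` triple instead of a `(ref, hyp)` pair, and `save_results` sorts the triples before writing. Because Python sorts tuples by their first element, the recognition files come out ordered by cut id, so `recogs-*.txt` files from two different decoding runs line up for a plain `diff`. A condensed sketch of the pattern, with the recipe-specific code omitted:

    cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

    for name, hyps in hyps_dict.items():
        this_batch = []
        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
            ref_words = ref_text.split()
            # cut_id comes first so that sorted() below orders by utterance id
            this_batch.append((cut_id, ref_words, hyp_words))
        results[name].extend(this_batch)

    # later, in save_results:
    results = sorted(results)  # lexicographic, keyed by cut id
    store_transcripts(filename=recog_path, texts=results)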
@@ -367,6 +367,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -379,8 +380,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))

             results[name].extend(this_batch)

@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -528,6 +530,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)

     dev = "dev"
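Downstream, `store_transcripts` in `icefall/utils.py` has to accept the new three-element tuples. Its updated body is not part of this diff; a plausible sketch, with the exact output format an assumption:

    from pathlib import Path
    from typing import Iterable, Tuple, Union

    def store_transcripts(
        filename: Union[str, Path],
        texts: Iterable[Tuple[str, str, str]],
    ) -> None:
        """Write one ref/hyp pair per utterance, keyed by cut id (sketch)."""
        with open(filename, "w") as f:
            for cut_id, ref, hyp in texts:
                print(f"{cut_id}:\tref={ref}", file=f)
                print(f"{cut_id}:\thyp={hyp}", file=f)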
@@ -374,6 +374,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -389,9 +390,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[lm_scale].extend(this_batch)

@@ -419,6 +420,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -537,6 +539,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
@@ -386,6 +386,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -401,9 +402,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[lm_scale].extend(this_batch)

@@ -431,6 +432,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -556,6 +558,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
@@ -377,6 +377,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -389,9 +390,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -416,6 +417,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -606,6 +608,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     aishell = AIShell(manifest_dir=args.manifest_dir)
     test_cuts = aishell.test_cuts()
@@ -241,6 +241,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -253,9 +254,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[lm_scale].extend(this_batch)

@@ -278,6 +279,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -365,6 +367,8 @@ def main():
     model.to(device)
     model.eval()

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
@@ -38,8 +38,8 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )


@@ -296,6 +296,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -307,9 +308,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -334,6 +335,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)

         # The following prints out WERs, per-word error statistics and aligned
@@ -438,6 +440,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
@@ -341,6 +341,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -353,9 +354,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -380,6 +381,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -496,6 +498,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     aishell = AIShell(manifest_dir=args.manifest_dir)
     test_cuts = aishell.test_cuts()
@@ -345,6 +345,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -357,9 +358,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -384,6 +385,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -498,6 +500,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
@@ -514,6 +514,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -527,8 +528,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))

             results[name].extend(this_batch)

@@ -553,6 +554,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -756,6 +758,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell2 = AiShell2AsrDataModule(args)

     valid_cuts = aishell2.valid_cuts()
@@ -378,6 +378,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -390,8 +391,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))

             results[name].extend(this_batch)

@@ -416,6 +417,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -607,6 +609,8 @@ def main():
         c.supervisions[0].text = text_normalize(text)
         return c

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell4 = Aishell4AsrDataModule(args)
     test_cuts = aishell4.test_cuts()
     test_cuts = test_cuts.map(text_normalize_for_cut)
@@ -367,6 +367,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -379,8 +380,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))

             results[name].extend(this_batch)

@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -535,6 +537,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     alimeeting = AlimeetingAsrDataModule(args)

     dev = "eval"
@@ -451,6 +451,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -469,9 +470,9 @@ def decode_dataset(
             for lm_scale, hyps in hyps_dict.items():
                 this_batch = []
                 assert len(hyps) == len(texts)
-                for hyp_words, ref_text in zip(hyps, texts):
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                     ref_words = ref_text.split()
-                    this_batch.append((ref_words, hyp_words))
+                    this_batch.append((cut_id, ref_words, hyp_words))

                 results[lm_scale].extend(this_batch)
         else:
@@ -512,6 +513,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         results = post_processing(results)
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -676,6 +678,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)

     dev_cuts = gigaspeech.dev_cuts()
|
@ -374,6 +374,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -386,9 +387,9 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
@ -414,6 +415,7 @@ def save_results(
|
|||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
results = post_processing(results)
|
results = post_processing(results)
|
||||||
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
@ -544,6 +546,8 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||||
|
|
||||||
dev_cuts = gigaspeech.dev_cuts()
|
dev_cuts = gigaspeech.dev_cuts()
|
||||||
|
@ -525,6 +525,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -544,9 +545,9 @@ def decode_dataset(
|
|||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[lm_scale].extend(this_batch)
|
results[lm_scale].extend(this_batch)
|
||||||
else:
|
else:
|
||||||
@ -586,6 +587,7 @@ def save_results(
|
|||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
if enable_log:
|
if enable_log:
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
@ -779,6 +781,8 @@ def main():
|
|||||||
)
|
)
|
||||||
rnn_lm_model.eval()
|
rnn_lm_model.eval()
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
@ -31,14 +31,13 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_decoding,
|
nbest_decoding,
|
||||||
@ -633,6 +632,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -652,9 +652,9 @@ def decode_dataset(
|
|||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[lm_scale].extend(this_batch)
|
results[lm_scale].extend(this_batch)
|
||||||
else:
|
else:
|
||||||
@ -694,6 +694,7 @@ def save_results(
|
|||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
if enable_log:
|
if enable_log:
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
@ -956,6 +957,8 @@ def main():
|
|||||||
)
|
)
|
||||||
rnn_lm_model.eval()
|
rnn_lm_model.eval()
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
@ -449,6 +449,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -466,9 +467,9 @@ def decode_dataset(
|
|||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[lm_scale].extend(this_batch)
|
results[lm_scale].extend(this_batch)
|
||||||
|
|
||||||
@ -496,6 +497,7 @@ def save_results(
|
|||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
if enable_log:
|
if enable_log:
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
@ -661,6 +663,8 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
# CAUTION: `test_sets` is for displaying only.
|
# CAUTION: `test_sets` is for displaying only.
|
||||||
# If you want to skip test-clean, you have to skip
|
# If you want to skip test-clean, you have to skip
|
||||||
|
@ -403,6 +403,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -415,9 +416,9 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
@ -442,6 +443,7 @@ def save_results(
|
|||||||
recog_path = (
|
recog_path = (
|
||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
@ -624,6 +626,8 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
@ -29,6 +29,7 @@ class Stream(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
|
cut_id: str,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
device: torch.device = torch.device("cpu"),
|
device: torch.device = torch.device("cpu"),
|
||||||
LOG_EPS: float = math.log(1e-10),
|
LOG_EPS: float = math.log(1e-10),
|
||||||
@ -44,6 +45,7 @@ class Stream(object):
|
|||||||
The device to run this stream.
|
The device to run this stream.
|
||||||
"""
|
"""
|
||||||
self.LOG_EPS = LOG_EPS
|
self.LOG_EPS = LOG_EPS
|
||||||
|
self.cut_id = cut_id
|
||||||
|
|
||||||
# Containing attention caches and convolution caches
|
# Containing attention caches and convolution caches
|
||||||
self.states: Optional[
|
self.states: Optional[
|
||||||
@ -138,6 +140,10 @@ class Stream(object):
|
|||||||
"""Return True if all feature frames are processed."""
|
"""Return True if all feature frames are processed."""
|
||||||
return self._done
|
return self._done
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
|
return self.cut_id
|
||||||
|
|
||||||
def decoding_result(self) -> List[int]:
|
def decoding_result(self) -> List[int]:
|
||||||
"""Obtain current decoding result."""
|
"""Obtain current decoding result."""
|
||||||
if self.decoding_method == "greedy_search":
|
if self.decoding_method == "greedy_search":
|
||||||
|
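The streaming `Stream` (and, later, `DecodeStream`) classes receive the cut id the same way: it is passed to the constructor and exposed through a read-only `id` property, so the decoding loop can key each finished stream's result without holding on to the cut itself. Usage as it appears in the streaming-decode hunks that follow:

    # One Stream per utterance; the cut id travels with the stream.
    stream = Stream(
        params=params,
        cut_id=cut.id,
        decoding_graph=decoding_graph,
        device=device,
        LOG_EPS=LOG_EPSILON,
    )
    # ... run the streams to completion ...
    decode_results.append(
        (
            streams[i].id,  # cut id first, matching the batch recipes
            streams[i].ground_truth.split(),
            sp.decode(streams[i].decoding_result()).split(),
        )
    )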
@@ -74,7 +74,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
-from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
 from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -678,6 +678,7 @@ def decode_dataset(
         # Each utterance has a Stream.
         stream = Stream(
             params=params,
+            cut_id=cut.id,
             decoding_graph=decoding_graph,
             device=device,
             LOG_EPS=LOG_EPSILON,
@@ -711,6 +712,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -731,6 +733,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -403,6 +403,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -415,9 +416,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -442,6 +443,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -624,6 +626,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -74,7 +74,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
-from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
 from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -678,6 +678,7 @@ def decode_dataset(
         # Each utterance has a Stream.
         stream = Stream(
             params=params,
+            cut_id=cut.id,
             decoding_graph=decoding_graph,
             device=device,
             LOG_EPS=LOG_EPSILON,
@@ -711,6 +712,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -731,6 +733,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -391,6 +391,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -403,9 +404,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -430,6 +431,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -612,6 +614,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -551,6 +551,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -564,9 +565,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -591,6 +592,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -631,6 +633,8 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    # we need cut ids to display recognition results.
+    args.return_cuts = True

     params = get_params()
     params.update(vars(args))
@@ -754,6 +758,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
|
@ -28,6 +28,7 @@ class DecodeStream(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
|
cut_id: str,
|
||||||
initial_states: List[torch.Tensor],
|
initial_states: List[torch.Tensor],
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
device: torch.device = torch.device("cpu"),
|
device: torch.device = torch.device("cpu"),
|
||||||
@ -48,6 +49,7 @@ class DecodeStream(object):
|
|||||||
assert device == decoding_graph.device
|
assert device == decoding_graph.device
|
||||||
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.cut_id = cut_id
|
||||||
self.LOG_EPS = math.log(1e-10)
|
self.LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
self.states = initial_states
|
self.states = initial_states
|
||||||
@ -102,6 +104,10 @@ class DecodeStream(object):
|
|||||||
"""Return True if all the features are processed."""
|
"""Return True if all the features are processed."""
|
||||||
return self._done
|
return self._done
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
|
return self.cut_id
|
||||||
|
|
||||||
def set_features(
|
def set_features(
|
||||||
self,
|
self,
|
||||||
features: torch.Tensor,
|
features: torch.Tensor,
|
||||||
|
@ -356,6 +356,7 @@ def decode_dataset(
|
|||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
decode_stream = DecodeStream(
|
decode_stream = DecodeStream(
|
||||||
params=params,
|
params=params,
|
||||||
|
cut_id=cut.id,
|
||||||
initial_states=initial_states,
|
initial_states=initial_states,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
device=device,
|
device=device,
|
||||||
@ -385,6 +386,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||||
)
|
)
|
||||||
@ -402,6 +404,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||||
)
|
)
|
||||||
|
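Note: finished streams are visited in descending index order, which matters if the surrounding loop also deletes them from `decode_streams` by index (as the streaming code presumably does): removing the higher indices first keeps the lower indices valid. A standalone illustration:

    items = ["a", "b", "c", "d"]
    finished = [1, 3]
    for i in sorted(finished, reverse=True):
        del items[i]  # deleting from the back never shifts the remaining indices
    assert items == ["a", "c"]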
@@ -32,7 +32,7 @@ from scaling import (
 )
 from torch import Tensor, nn

-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask


 class Conformer(EncoderInterface):
@@ -155,7 +155,7 @@ class Conformer(EncoderInterface):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1

-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert x.size(0) == lengths.max().item()

         src_key_padding_mask = make_pad_mask(lengths)
@@ -788,7 +788,7 @@ class RelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
-        if torch.jit.is_tracing():
+        if is_jit_tracing():
             # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
             # It assumes that the maximum input won't have more than
             # 10k frames.
@@ -1015,12 +1015,12 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape

         time2 = time1 + left_context
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert (
                 n == left_context + 2 * time1 - 1
             ), f"{n} == {left_context} + 2 * {time1} - 1"

-        if torch.jit.is_tracing():
+        if is_jit_tracing():
             rows = torch.arange(start=time1 - 1, end=-1, step=-1)
             cols = torch.arange(time2)
             rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
@@ -1111,12 +1111,12 @@ class RelPositionMultiheadAttention(nn.Module):
         """

         tgt_len, bsz, embed_dim = query.size()
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert embed_dim == embed_dim_to_check
             assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

         head_dim = embed_dim // num_heads
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert (
                 head_dim * num_heads == embed_dim
             ), "embed_dim must be divisible by num_heads"
@@ -1232,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):

         src_len = k.size(0)

-        if key_padding_mask is not None and not torch.jit.is_tracing():
+        if key_padding_mask is not None and not is_jit_tracing():
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
@@ -1243,7 +1243,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)

         pos_emb_bsz = pos_emb.size(0)
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert pos_emb_bsz in (1, bsz)  # actually it is 1

         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
@@ -1280,7 +1280,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )

-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert list(attn_output_weights.size()) == [
                 bsz * num_heads,
                 tgt_len,
@@ -1345,7 +1345,7 @@ class RelPositionMultiheadAttention(nn.Module):

         attn_output = torch.bmm(attn_output_weights, v)

-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert list(attn_output.size()) == [
                 bsz * num_heads,
                 tgt_len,
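Note: per the commit message, `torch.jit.is_tracing()` is problematic on PyTorch releases before v1.9.0, so every call site above is routed through an `is_jit_tracing` helper imported from `icefall.utils`. A plausible shape for such a helper (a sketch under that assumption; the exact body lives in icefall/utils.py):

    import torch

    def is_jit_tracing() -> bool:
        # While a module is being scripted we are, by definition, not tracing;
        # checking is_scripting() first keeps the helper usable inside
        # torch.jit.script-ed code on older PyTorch versions.
        if torch.jit.is_scripting():
            return False
        elif torch.jit.is_tracing():
            return True
        return False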
@@ -574,6 +574,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -587,9 +588,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -614,6 +615,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -777,6 +779,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Union
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scaling import ScaledConv1d, ScaledEmbedding

+from icefall.utils import is_jit_tracing
+

 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:
@@ -80,7 +80,10 @@ class Decoder(nn.Module):
             self.conv = nn.Identity()

     def forward(
-        self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
+        self,
+        y: torch.Tensor,
+        need_pad: bool = True  # Annotation should be Union[bool, torch.Tensor]
+        # but, torch.jit.script does not support Union.
     ) -> torch.Tensor:
         """
         Args:
@@ -108,7 +111,7 @@ class Decoder(nn.Module):
         else:
             # During inference time, there is no need to do extra padding
             # as we only need one output
-            if not torch.jit.is_tracing():
+            if not is_jit_tracing():
                 assert embedding_out.size(-1) == self.context_size
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
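Note: the forward signature above loses its `Union[bool, torch.Tensor]` annotation because, as the inline comment records, torch.jit.script did not support `Union` types at the time. A small illustration of the failure mode (hypothetical toy module, for demonstration only):

    import torch

    class Toy(torch.nn.Module):
        # Annotating flag as Union[bool, torch.Tensor] would make
        # torch.jit.script(Toy()) fail on the PyTorch versions targeted here;
        # a plain bool annotation compiles cleanly.
        def forward(self, x: torch.Tensor, flag: bool = True) -> torch.Tensor:
            return x if flag else x * 0.0

    scripted = torch.jit.script(Toy())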
@@ -18,6 +18,8 @@ import torch
 import torch.nn as nn
 from scaling import ScaledLinear

+from icefall.utils import is_jit_tracing
+

 class Joiner(nn.Module):
     def __init__(
@@ -52,7 +54,7 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert encoder_out.ndim == decoder_out.ndim
             assert encoder_out.ndim in (2, 4)
             assert encoder_out.shape == decoder_out.shape
@@ -23,6 +23,8 @@ import torch
 import torch.nn as nn
 from torch import Tensor

+from icefall.utils import is_jit_tracing
+

 def _ntuple(n):
     def parse(x):
@@ -152,7 +154,7 @@ class BasicNorm(torch.nn.Module):
         self.register_buffer("eps", torch.tensor(eps).log().detach())

     def forward(self, x: Tensor) -> Tensor:
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert x.shape[self.channel_dim] == self.num_channels
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
@@ -424,7 +426,7 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting() or is_jit_tracing():
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -473,7 +475,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting() or is_jit_tracing():
             return x * torch.sigmoid(x - 1.0)
         else:
             return DoubleSwishFunction.apply(x)
@@ -358,6 +358,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -388,6 +389,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -405,6 +407,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -422,6 +422,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -434,9 +435,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -610,6 +611,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)

@@ -745,6 +745,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -760,9 +761,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -787,6 +788,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -1067,6 +1069,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

@@ -25,6 +25,7 @@ life by converting them to their non-scaled version during inference.

 import copy
 import re
+from typing import List

 import torch
 import torch.nn as nn
@@ -54,7 +55,10 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
         in_features=scaled_linear.in_features,
         out_features=scaled_linear.out_features,
         bias=True,  # otherwise, it throws errors when converting to PNNX format.
-        device=weight.device,
+        # device=weight.device,  # PyTorch versions before v1.9.0 do not have
+        # this argument. Comment it out for now; we will
+        # see whether it raises an error for versions
+        # after v1.9.0.
     )
     linear.weight.data.copy_(weight)

@@ -164,6 +168,24 @@ def scaled_embedding_to_embedding(
     return embedding


+# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule
+# get_submodule was added to nn.Module at v1.9.0
+def get_submodule(model, target):
+    if target == "":
+        return model
+    atoms: List[str] = target.split(".")
+    mod: torch.nn.Module = model
+    for item in atoms:
+        if not hasattr(mod, item):
+            raise AttributeError(
+                mod._get_name() + " has no attribute `" + item + "`"
+            )
+        mod = getattr(mod, item)
+        if not isinstance(mod, torch.nn.Module):
+            raise AttributeError("`" + item + "` is not an nn.Module")
+    return mod
+
+
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given model to their unscaled versions `nn.Linear`, `nn.Conv1d`,
@@ -200,7 +222,7 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     for k, v in d.items():
         if "." in k:
             parent, child = k.rsplit(".", maxsplit=1)
-            setattr(model.get_submodule(parent), child, v)
+            setattr(get_submodule(model, parent), child, v)
         else:
             setattr(model, k, v)

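Note: the backported `get_submodule` walks a dotted module path one attribute at a time, which `convert_scaled_to_non_scaled` needs because the keys of `d` are fully qualified submodule names. A quick usage sketch with a toy model (illustrative names only):

    import torch.nn as nn

    model = nn.Module()
    model.encoder = nn.Module()
    model.encoder.proj = nn.Linear(4, 8)

    # "encoder.proj" is split on "." and resolved attribute by attribute
    proj = get_submodule(model, "encoder.proj")
    assert proj is model.encoder.proj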
@@ -359,6 +359,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -389,6 +390,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -406,6 +408,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -578,6 +578,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -591,9 +592,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -618,6 +619,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -831,6 +833,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -371,6 +371,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
        decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -401,6 +402,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -418,6 +420,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -564,6 +564,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -577,9 +578,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -604,6 +605,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -817,6 +819,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -371,6 +371,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -401,6 +402,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -418,6 +420,7 @@ def decode_dataset(
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
@@ -387,6 +387,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -399,9 +400,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -426,6 +427,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -608,6 +610,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -311,6 +311,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -324,9 +325,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[lm_scale].extend(this_batch)

@@ -349,6 +350,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -473,6 +475,8 @@ def main():
     model.to(device)
     model.eval()

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -295,6 +295,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -306,9 +307,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -333,6 +334,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -424,6 +426,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -292,6 +292,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -303,9 +304,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -330,6 +331,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -422,6 +424,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -500,6 +502,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -500,6 +502,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)

     test_clean_cuts = librispeech.test_clean_cuts()
@@ -351,6 +351,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -363,9 +364,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -390,6 +391,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -503,6 +505,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

@@ -365,6 +365,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -377,9 +378,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -561,6 +563,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     spgispeech = SPGISpeechAsrDataModule(args)

     dev_cuts = spgispeech.dev_cuts()
@@ -453,6 +453,7 @@ def decode_dataset(
     zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         zh_texts = []
         en_texts = []
         for i in range(len(texts)):
@@ -487,14 +488,14 @@ def decode_dataset(
             # print(hyps_texts)
             hyps, zh_hyps, en_hyps = hyps_texts
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))

-            for hyp_words, ref_text in zip(zh_hyps, zh_texts):
-                this_batch_zh.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, zh_hyps, zh_texts):
+                this_batch_zh.append((cut_id, ref_text, hyp_words))

-            for hyp_words, ref_text in zip(en_hyps, en_texts):
-                this_batch_en.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, en_hyps, en_texts):
+                this_batch_en.append((cut_id, ref_text, hyp_words))

             results[name].extend(this_batch)
             zh_results[name + "_zh"].extend(this_batch_zh)
@@ -521,6 +522,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -710,6 +712,8 @@ def main():
             c.supervisions[0].text = text_normalize(text)
             return c

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tal_csasr = TAL_CSASRAsrDataModule(args)

     dev_cuts = tal_csasr.valid_cuts()
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -498,6 +500,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tedlium = TedLiumAsrDataModule(args)
     dev_cuts = tedlium.dev_cuts()
     test_cuts = tedlium.test_cuts()
@@ -325,6 +325,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -336,9 +337,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)

@@ -363,6 +364,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -462,6 +464,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tedlium = TedLiumAsrDataModule(args)
     dev_cuts = tedlium.dev_cuts()
     test_cuts = tedlium.test_cuts()
@@ -311,6 +311,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
             params=params,
@@ -324,9 +325,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))

             results[lm_scale].extend(this_batch)

@@ -349,6 +350,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -468,6 +470,8 @@ def main():
     model.to(device)
     model.eval()

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     timit = TimitAsrDataModule(args)
     test_set = "TEST"
     test_dl = timit.test_dataloaders()
@@ -310,6 +310,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -323,9 +324,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
 
             results[lm_scale].extend(this_batch)
 
@@ -348,6 +349,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
@@ -467,6 +469,8 @@ def main():
     model.to(device)
     model.eval()
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     timit = TimitAsrDataModule(args)
     test_set = "TEST"
     test_dl = timit.test_dataloaders()
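Both TIMIT decode scripts, like every other recipe in this diff, force `args.return_cuts = True` before constructing the data module; that flag is what makes `batch["supervisions"]["cut"]` available during decoding. A hypothetical, much-simplified stand-in for a data module shows the mechanism (the real modules pass the flag through to lhotse's `K2SpeechRecognitionDataset`; `FakeCut` and `make_batch` below are invented for illustration):

# Hypothetical sketch: why decode scripts force return_cuts. Without it,
# batches carry no cuts and hence no ids for labelling results.
class FakeCut:
    def __init__(self, cut_id: str, text: str):
        self.id = cut_id
        self.text = text


def make_batch(cuts, return_cuts: bool):
    batch = {"supervisions": {"text": [c.text for c in cuts]}}
    if return_cuts:
        # lhotse's dataset stores the cuts themselves under "cut"
        # when constructed with return_cuts=True.
        batch["supervisions"]["cut"] = cuts
    return batch


cuts = [FakeCut("cut-001", "yes no"), FakeCut("cut-002", "no no yes")]
batch = make_batch(cuts, return_cuts=True)
print([cut.id for cut in batch["supervisions"]["cut"]])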
@@ -491,6 +491,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text)) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -504,8 +505,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
 
             results[name].extend(this_batch)
 
@@ -530,6 +531,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
@@ -678,6 +680,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
     dev = "dev"
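In the WenetSpeech hunks, `texts = [list(str(text)) for text in texts]` turns each Mandarin reference into a list of characters instead of space-separated words, so the downstream scoring effectively measures character error rate; the cut ids are threaded through in exactly the same way. A small illustration with an invented transcript and hypothesis:

# Character-level references: splitting a Mandarin string into a list of
# characters means alignment-based scoring counts character errors (CER).
texts = ["你好世界"]
texts = [list(str(text)) for text in texts]
print(texts)  # [['你', '好', '世', '界']]

cut_ids = ["cut-001"]  # would come from batch["supervisions"]["cut"]
hyps = [["你", "好", "世", "届"]]  # invented hypothesis with one char error
results = [
    (cut_id, ref, hyp) for cut_id, ref, hyp in zip(cut_ids, texts, hyps)
]
print(results)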
@@ -461,6 +461,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text)) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -473,8 +474,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
 
             results[name].extend(this_batch)
 
@@ -499,6 +500,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
@@ -682,6 +684,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
     dev = "dev"
@@ -28,6 +28,7 @@ class DecodeStream(object):
     def __init__(
         self,
         params: AttributeDict,
+        cut_id: str,
         initial_states: List[torch.Tensor],
         decoding_graph: Optional[k2.Fsa] = None,
         device: torch.device = torch.device("cpu"),
@@ -48,6 +49,7 @@ class DecodeStream(object):
             assert device == decoding_graph.device
 
         self.params = params
+        self.cut_id = cut_id
         self.LOG_EPS = math.log(1e-10)
 
         self.states = initial_states
@@ -102,6 +104,10 @@ class DecodeStream(object):
         """Return True if all the features are processed."""
         return self._done
 
+    @property
+    def id(self) -> str:
+        return self.cut_id
+
     def set_features(
         self,
         features: torch.Tensor,
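The streaming decoder gets the same treatment: each `DecodeStream` now remembers the cut it was created from and exposes it through a read-only `id` property. A stripped-down sketch of the pattern (the real class also manages features, decoder states, and the decoding graph; `DecodeStreamSketch` is an invented name):

import math
from typing import List


class DecodeStreamSketch:
    """Minimal stand-in for icefall's DecodeStream: it only shows how the
    cut id is stored at construction time and exposed as a property."""

    def __init__(self, cut_id: str, initial_states: List[float]) -> None:
        self.cut_id = cut_id
        self.LOG_EPS = math.log(1e-10)
        self.states = initial_states
        self._done = False

    @property
    def id(self) -> str:
        # Read-only access; decode_dataset reads stream.id when assembling
        # (cut_id, ref, hyp) triples for finished streams.
        return self.cut_id


stream = DecodeStreamSketch(cut_id="cut-001", initial_states=[0.0])
print(stream.id)  # -> cut-001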
@@ -396,6 +396,7 @@ def decode_dataset(
             # each utterance has a DecodeStream.
             decode_stream = DecodeStream(
                 params=params,
+                cut_id=cut.id,
                 initial_states=initial_states,
                 decoding_graph=decoding_graph,
                 device=device,
@@ -423,6 +424,7 @@ def decode_dataset(
                 hyp = decode_streams[i].decoding_result()
                 decode_results.append(
                     (
+                        decode_streams[i].id,
                         list(decode_streams[i].ground_truth),
                         [lexicon.token_table[idx] for idx in hyp],
                     )
@@ -441,6 +443,7 @@ def decode_dataset(
         hyp = decode_streams[i].decoding_result()
         decode_results.append(
             (
+                decode_streams[i].id,
                 list(decode_streams[i].ground_truth),
                 [lexicon.token_table[idx] for idx in hyp],
             )
@@ -178,6 +178,7 @@ def decode_dataset(
     results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps = decode_one_batch(
             params=params,
@@ -189,9 +190,9 @@ def decode_dataset(
 
         this_batch = []
         assert len(hyps) == len(texts)
-        for hyp_words, ref_text in zip(hyps, texts):
+        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
 
         results.extend(this_batch)
 
@@ -237,6 +238,7 @@ def save_results(
       Return None.
     """
     recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+    results = sorted(results)
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
@@ -303,6 +305,8 @@ def main():
     model.to(device)
     model.eval()
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     yes_no = YesNoAsrDataModule(args)
     test_dl = yes_no.test_dataloaders()
     results = decode_dataset(
@@ -165,6 +165,7 @@ def decode_dataset(
     results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps = decode_one_batch(
             params=params,
@@ -174,9 +175,9 @@ def decode_dataset(
 
         this_batch = []
         assert len(hyps) == len(texts)
-        for hyp_words, ref_text in zip(hyps, texts):
+        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
 
         results.extend(this_batch)
 
@@ -222,6 +223,7 @@ def save_results(
       Return None.
     """
     recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+    results = sorted(results)
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
@@ -291,6 +293,8 @@ def main():
     model.eval()
     model.device = device
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     yes_no = YesNoAsrDataModule(args)
     test_dl = yes_no.test_dataloaders()
     results = decode_dataset(
@@ -49,6 +49,7 @@ from .utils import (
     get_alignments,
     get_executor,
     get_texts,
+    is_jit_tracing,
     l1_norm,
     l2_norm,
     linf_norm,
@@ -42,6 +42,18 @@ from icefall.checkpoint import average_checkpoints
 Pathlike = Union[str, Path]
 
 
+# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
+# Fixed: https://github.com/pytorch/pytorch/pull/49853
+# The fix was included in v1.9.0
+# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
+def is_jit_tracing():
+    if torch.jit.is_scripting():
+        return False
+    elif torch.jit.is_tracing():
+        return True
+    return False
+
+
 @contextmanager
 def get_executor():
     # We'll either return a process pool or a distributed worker pool.
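`is_jit_tracing()` exists because, before the PyTorch v1.9.0 fix referenced in the comments above, `torch.jit.is_tracing()` could misreport inside scripted code; checking `torch.jit.is_scripting()` first sidesteps that. The typical call site guards code that should only run outside tracing, such as shape assertions that would otherwise be evaluated once and baked into the trace. A hedged sketch of such a call site (the `Scaler` module is invented for illustration):

import torch


def is_jit_tracing():
    # Same logic as the helper added in this commit: the scripting check
    # runs first so that pre-v1.9.0 tracing quirks cannot interfere.
    if torch.jit.is_scripting():
        return False
    elif torch.jit.is_tracing():
        return True
    return False


class Scaler(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not is_jit_tracing():
            # Plain assert: runs in eager/scripted mode, skipped under
            # tracing, where it would only be checked once and discarded.
            assert x.dim() == 2, "expected (batch, dim)"
        return x * 2.0


model = Scaler()
traced = torch.jit.trace(model, torch.zeros(3, 4))  # no assert baked in
print(traced(torch.ones(3, 4)).shape)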
@@ -321,7 +333,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
 
 
 def store_transcripts(
-    filename: Pathlike, texts: Iterable[Tuple[str, str]]
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
 ) -> None:
     """Save predicted results and reference transcripts to a file.
 
@@ -329,15 +341,15 @@ def store_transcripts(
       filename:
         File to save the results to.
       texts:
-        An iterable of tuples. The first element is the reference transcript
-        while the second element is the predicted result.
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
     Returns:
       Return None.
     """
     with open(filename, "w") as f:
-        for ref, hyp in texts:
-            print(f"ref={ref}", file=f)
-            print(f"hyp={hyp}", file=f)
+        for cut_id, ref, hyp in texts:
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
 
 
 def write_error_stats(
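With the cut id prefixed to every `ref=`/`hyp=` pair, a recogs file is both self-describing and stable under sorting. A quick sketch of what the updated `store_transcripts` produces, using invented ids and plain strings in place of the word lists icefall passes:

# Writing a few (cut_id, ref, hyp) triples the way the updated
# store_transcripts does, then showing the resulting file contents.
results = sorted(
    [
        ("cut-002", "hello world", "hello word"),
        ("cut-001", "good morning", "good morning"),
    ]
)
with open("recogs-demo.txt", "w") as f:
    for cut_id, ref, hyp in results:
        print(f"{cut_id}:\tref={ref}", file=f)
        print(f"{cut_id}:\thyp={hyp}", file=f)

print(open("recogs-demo.txt").read())
# cut-001:  ref=good morning   (the ":\t" renders as a tab)
# cut-001:  hyp=good morning
# cut-002:  ref=hello world
# cut-002:  hyp=hello word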
@@ -372,8 +384,8 @@ def write_error_stats(
         The reference word `SIR` is missing in the predicted
         results (a deletion error).
       results:
-        An iterable of tuples. The first element is the reference transcript
-        while the second element is the predicted result.
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
       enable_log:
         If True, also print detailed WER to the console.
         Otherwise, it is written only to the given file.
@@ -389,7 +401,7 @@ def write_error_stats(
     words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
     num_corr = 0
     ERR = "*"
-    for ref, hyp in results:
+    for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:
@@ -405,7 +417,7 @@ def write_error_stats(
             else:
                 words[ref_word][0] += 1
                 num_corr += 1
-    ref_len = sum([len(r) for r, _ in results])
+    ref_len = sum([len(r) for _, r, _ in results])
     sub_errs = sum(subs.values())
     ins_errs = sum(ins.values())
     del_errs = sum(dels.values())
@@ -434,7 +446,7 @@ def write_error_stats(
 
     print("", file=f)
     print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
-    for ref, hyp in results:
+    for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         combine_successive_errors = True
         if combine_successive_errors:
@@ -461,7 +473,8 @@ def write_error_stats(
             ]
 
         print(
-            " ".join(
+            f"{cut_id}:\t"
+            + " ".join(
                 (
                     ref_word
                     if ref_word == hyp_word
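The same id prefix lands on every line of the PER-UTT DETAILS section, so an error like `(world->word)` can be traced straight back to its utterance. A rough sketch of the line format, aligning with kaldialign the way `write_error_stats` does (assumes the kaldialign package is installed; ids and words are invented, and the `combine_successive_errors` merging is omitted):

import kaldialign  # same alignment library used by write_error_stats

ERR = "*"
results = [("cut-001", ["hello", "world"], ["hello", "word"])]
for cut_id, ref, hyp in results:
    ali = kaldialign.align(ref, hyp, ERR)
    rendered = " ".join(
        ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
        for ref_word, hyp_word in ali
    )
    print(f"{cut_id}:\t{rendered}")
# cut-001:  hello (world->word)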