Sort results to make it more convenient to compare decoding results (#522)

* Sort results to make it more convenient to compare decoding results

* Add cut_id to recognition results

* Add cut_id to results for all recipes

* Fix torch.jit.script

* Fix comments

* Minor fixes

* Fix torch.jit.is_tracing() for PyTorch versions before v1.9.0
Wei Kang 2022-08-12 07:12:50 +08:00 committed by GitHub
parent 5149788cb2
commit 5c17255eec
60 changed files with 379 additions and 126 deletions
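
In short: every recipe's decode_dataset() now attaches the lhotse cut id to each (ref, hyp) pair, and save_results() sorts the tuples before store_transcripts() writes them, so recogs-*.txt files from two decoding runs can be compared line by line. A minimal, self-contained sketch of that flow (simplified; the real recipes take cut_ids from batch["supervisions"]["cut"], which is only populated when args.return_cuts = True):

    from typing import List, Tuple

    def collect_batch(
        cut_ids: List[str], hyps: List[List[str]], texts: List[str]
    ) -> List[Tuple[str, List[str], List[str]]]:
        # Mirrors the per-batch loop added in the diffs below:
        # each entry is (cut_id, ref_words, hyp_words).
        return [
            (cut_id, ref_text.split(), hyp_words)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
        ]

    results = collect_batch(
        cut_ids=["cut-b", "cut-a"],  # made-up ids; batches arrive in arbitrary order
        hyps=[["hello", "world"], ["foo"]],
        texts=["hello world", "foo"],
    )
    # Tuples compare by their first element, the cut id, so sorting puts the
    # same utterance on the same line regardless of batching.
    results = sorted(results)
    assert [r[0] for r in results] == ["cut-a", "cut-b"]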

View File

@@ -367,6 +367,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -379,8 +380,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -528,6 +530,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
     dev = "dev"

View File

@@ -374,6 +374,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -389,9 +390,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -419,6 +420,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -537,6 +539,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)

View File

@@ -386,6 +386,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -401,9 +402,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -431,6 +432,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -556,6 +558,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)

View File

@@ -377,6 +377,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -389,9 +390,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -416,6 +417,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -606,6 +608,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     aishell = AIShell(manifest_dir=args.manifest_dir)
     test_cuts = aishell.test_cuts()

View File

@@ -241,6 +241,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -253,9 +254,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -278,6 +279,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -365,6 +367,8 @@ def main():
     model.to(device)
     model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)

View File

@@ -38,8 +38,8 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
@@ -296,6 +296,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -307,9 +308,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -334,6 +335,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         # The following prints out WERs, per-word error statistics and aligned
@@ -438,6 +440,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)

View File

@@ -341,6 +341,7 @@ def decode_dataset(
     results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -353,9 +354,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -380,6 +381,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -496,6 +498,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     aishell = AIShell(manifest_dir=args.manifest_dir)
     test_cuts = aishell.test_cuts()

View File

@@ -345,6 +345,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -357,9 +358,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -384,6 +385,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -498,6 +500,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)

View File

@@ -514,6 +514,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -527,8 +528,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -553,6 +554,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -756,6 +758,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell2 = AiShell2AsrDataModule(args)
     valid_cuts = aishell2.valid_cuts()

View File

@@ -378,6 +378,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -390,8 +391,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -416,6 +417,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -607,6 +609,8 @@ def main():
         c.supervisions[0].text = text_normalize(text)
         return c
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     aishell4 = Aishell4AsrDataModule(args)
     test_cuts = aishell4.test_cuts()
     test_cuts = test_cuts.map(text_normalize_for_cut)

View File

@@ -367,6 +367,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -379,8 +380,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -535,6 +537,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     alimeeting = AlimeetingAsrDataModule(args)
     dev = "eval"

View File

@@ -451,6 +451,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -469,9 +470,9 @@ def decode_dataset(
             for lm_scale, hyps in hyps_dict.items():
                 this_batch = []
                 assert len(hyps) == len(texts)
-                for hyp_words, ref_text in zip(hyps, texts):
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                     ref_words = ref_text.split()
-                    this_batch.append((ref_words, hyp_words))
+                    this_batch.append((cut_id, ref_words, hyp_words))
                 results[lm_scale].extend(this_batch)
         else:
@@ -512,6 +513,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         results = post_processing(results)
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -676,6 +678,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)
     dev_cuts = gigaspeech.dev_cuts()

View File

@@ -374,6 +374,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -386,9 +387,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -414,6 +415,7 @@ def save_results(
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = post_processing(results)
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -544,6 +546,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)
     dev_cuts = gigaspeech.dev_cuts()

View File

@@ -525,6 +525,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -544,9 +545,9 @@ def decode_dataset(
             for lm_scale, hyps in hyps_dict.items():
                 this_batch = []
                 assert len(hyps) == len(texts)
-                for hyp_words, ref_text in zip(hyps, texts):
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                     ref_words = ref_text.split()
-                    this_batch.append((ref_words, hyp_words))
+                    this_batch.append((cut_id, ref_words, hyp_words))
                 results[lm_scale].extend(this_batch)
         else:
@@ -586,6 +587,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -779,6 +781,8 @@ def main():
         )
         rnn_lm_model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -31,14 +31,13 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -633,6 +632,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -652,9 +652,9 @@ def decode_dataset(
             for lm_scale, hyps in hyps_dict.items():
                 this_batch = []
                 assert len(hyps) == len(texts)
-                for hyp_words, ref_text in zip(hyps, texts):
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                     ref_words = ref_text.split()
-                    this_batch.append((ref_words, hyp_words))
+                    this_batch.append((cut_id, ref_words, hyp_words))
                 results[lm_scale].extend(this_batch)
         else:
@@ -694,6 +694,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -956,6 +957,8 @@ def main():
         )
         rnn_lm_model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -449,6 +449,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -466,9 +467,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -496,6 +497,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
@@ -661,6 +663,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     # CAUTION: `test_sets` is for displaying only.
     # If you want to skip test-clean, you have to skip

View File

@@ -403,6 +403,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -415,9 +416,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -442,6 +443,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -624,6 +626,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -29,6 +29,7 @@ class Stream(object):
     def __init__(
         self,
         params: AttributeDict,
+        cut_id: str,
        decoding_graph: Optional[k2.Fsa] = None,
         device: torch.device = torch.device("cpu"),
         LOG_EPS: float = math.log(1e-10),
@@ -44,6 +45,7 @@ class Stream(object):
             The device to run this stream.
         """
         self.LOG_EPS = LOG_EPS
+        self.cut_id = cut_id
         # Containing attention caches and convolution caches
         self.states: Optional[
@@ -138,6 +140,10 @@ class Stream(object):
         """Return True if all feature frames are processed."""
         return self._done
+    @property
+    def id(self) -> str:
+        return self.cut_id
+
     def decoding_result(self) -> List[int]:
         """Obtain current decoding result."""
         if self.decoding_method == "greedy_search":
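
The id stored here is read back later as streams[i].id when decode results are assembled (see the hunks below). A tiny self-contained sketch of just this plumbing, with a made-up LibriSpeech-style cut id:

    class StreamSketch:
        # Only the cut-id bookkeeping from the Stream class above.
        def __init__(self, cut_id: str) -> None:
            self.cut_id = cut_id

        @property
        def id(self) -> str:
            # Read-only accessor used as streams[i].id when building
            # (cut_id, ref_words, hyp_words) decode results.
            return self.cut_id

    stream = StreamSketch(cut_id="1089-134686-0001")  # hypothetical id
    assert stream.id == "1089-134686-0001"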

View File

@@ -74,7 +74,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 import k2
-from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
 from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -678,6 +678,7 @@ def decode_dataset(
         # Each utterance has a Stream.
         stream = Stream(
             params=params,
+            cut_id=cut.id,
             decoding_graph=decoding_graph,
             device=device,
             LOG_EPS=LOG_EPSILON,
@@ -711,6 +712,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -731,6 +733,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                streams[i].id,
                 streams[i].ground_truth.split(),
                 sp.decode(streams[i].decoding_result()).split(),
             )

View File

@@ -403,6 +403,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -415,9 +416,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -442,6 +443,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -624,6 +626,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -74,7 +74,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 import k2
-from lhotse import CutSet
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from emformer import LOG_EPSILON, stack_states, unstack_states
 from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -678,6 +678,7 @@ def decode_dataset(
         # Each utterance has a Stream.
         stream = Stream(
             params=params,
+            cut_id=cut.id,
             decoding_graph=decoding_graph,
             device=device,
             LOG_EPS=LOG_EPSILON,
@@ -711,6 +712,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    streams[i].id,
                     streams[i].ground_truth.split(),
                     sp.decode(streams[i].decoding_result()).split(),
                 )
@@ -731,6 +733,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                streams[i].id,
                 streams[i].ground_truth.split(),
                 sp.decode(streams[i].decoding_result()).split(),
             )

View File

@@ -391,6 +391,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -403,9 +404,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -430,6 +431,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -612,6 +614,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -551,6 +551,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -564,9 +565,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -591,6 +592,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -631,6 +633,8 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     params = get_params()
     params.update(vars(args))
@@ -754,6 +758,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -28,6 +28,7 @@ class DecodeStream(object):
     def __init__(
         self,
         params: AttributeDict,
+        cut_id: str,
         initial_states: List[torch.Tensor],
         decoding_graph: Optional[k2.Fsa] = None,
         device: torch.device = torch.device("cpu"),
@@ -48,6 +49,7 @@ class DecodeStream(object):
             assert device == decoding_graph.device
         self.params = params
+        self.cut_id = cut_id
         self.LOG_EPS = math.log(1e-10)
         self.states = initial_states
@@ -102,6 +104,10 @@ class DecodeStream(object):
         """Return True if all the features are processed."""
         return self._done
+    @property
+    def id(self) -> str:
+        return self.cut_id
+
     def set_features(
         self,
         features: torch.Tensor,

View File

@@ -356,6 +356,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -385,6 +386,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
                     sp.decode(decode_streams[i].decoding_result()).split(),
                 )
@@ -402,6 +404,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
                 sp.decode(decode_streams[i].decoding_result()).split(),
             )

View File

@@ -32,7 +32,7 @@ from scaling import (
 )
 from torch import Tensor, nn
-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
 class Conformer(EncoderInterface):
@@ -155,7 +155,7 @@ class Conformer(EncoderInterface):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert x.size(0) == lengths.max().item()
         src_key_padding_mask = make_pad_mask(lengths)
@@ -788,7 +788,7 @@ class RelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
-        if torch.jit.is_tracing():
+        if is_jit_tracing():
             # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
             # It assumes that the maximum input won't have more than
             # 10k frames.
@@ -1015,12 +1015,12 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert (
                 n == left_context + 2 * time1 - 1
             ), f"{n} == {left_context} + 2 * {time1} - 1"
-        if torch.jit.is_tracing():
+        if is_jit_tracing():
             rows = torch.arange(start=time1 - 1, end=-1, step=-1)
             cols = torch.arange(time2)
             rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
@@ -1111,12 +1111,12 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         tgt_len, bsz, embed_dim = query.size()
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert embed_dim == embed_dim_to_check
             assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
         head_dim = embed_dim // num_heads
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert (
                 head_dim * num_heads == embed_dim
             ), "embed_dim must be divisible by num_heads"
@@ -1232,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
         src_len = k.size(0)
-        if key_padding_mask is not None and not torch.jit.is_tracing():
+        if key_padding_mask is not None and not is_jit_tracing():
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
@@ -1243,7 +1243,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)
         pos_emb_bsz = pos_emb.size(0)
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert pos_emb_bsz in (1, bsz)  # actually it is 1
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
@@ -1280,7 +1280,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert list(attn_output_weights.size()) == [
                 bsz * num_heads,
                 tgt_len,
@@ -1345,7 +1345,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert list(attn_output.size()) == [
                 bsz * num_heads,
                 tgt_len,
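
Note that is_jit_tracing comes from icefall.utils; its definition is not part of this commit view. Per the commit message, it exists because torch.jit.is_tracing() is not usable as a public API on PyTorch releases before v1.9.0. A plausible sketch of such a version-guarded helper (an assumption about the actual icefall implementation, which may differ):

    import torch

    def is_jit_tracing() -> bool:
        # Scripted modules are compiled, not traced.
        if torch.jit.is_scripting():
            return False
        # On newer PyTorch the public API exists; on older releases
        # (assumed < 1.9.0) fall back to the internal tracing-state check.
        if hasattr(torch.jit, "is_tracing"):
            return torch.jit.is_tracing()
        return torch._C._get_tracing_state() is not None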

View File

@@ -574,6 +574,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -587,9 +588,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -614,6 +615,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -777,6 +779,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()

View File

@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scaling import ScaledConv1d, ScaledEmbedding
+from icefall.utils import is_jit_tracing
 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:
@@ -80,7 +80,10 @@ class Decoder(nn.Module):
             self.conv = nn.Identity()
     def forward(
-        self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
+        self,
+        y: torch.Tensor,
+        need_pad: bool = True  # Annotation should be Union[bool, torch.Tensor]
+        # but, torch.jit.script does not support Union.
     ) -> torch.Tensor:
         """
         Args:
@@ -108,7 +111,7 @@ class Decoder(nn.Module):
         else:
             # During inference time, there is no need to do extra padding
             # as we only need one output
-            if not torch.jit.is_tracing():
+            if not is_jit_tracing():
                 assert embedding_out.size(-1) == self.context_size
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
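
The signature change above works around torch.jit.script of that era rejecting Union annotations. A minimal sketch of the scriptable pattern, assuming only stock PyTorch (NeedPadSketch is a hypothetical stand-in, not the real Decoder):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class NeedPadSketch(nn.Module):
        # Annotating need_pad as plain bool keeps the module scriptable on
        # older PyTorch; Union[bool, torch.Tensor] would fail to compile there.
        def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
            if need_pad:
                return F.pad(y, (1, 0))  # left-pad the last dimension by one
            return y

    scripted = torch.jit.script(NeedPadSketch())
    print(scripted(torch.zeros(2, 3)).shape)         # torch.Size([2, 4])
    print(scripted(torch.zeros(2, 3), False).shape)  # torch.Size([2, 3])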

View File

@@ -18,6 +18,8 @@ import torch
 import torch.nn as nn
 from scaling import ScaledLinear
+from icefall.utils import is_jit_tracing
+
 class Joiner(nn.Module):
     def __init__(
@@ -52,7 +54,7 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert encoder_out.ndim == decoder_out.ndim
             assert encoder_out.ndim in (2, 4)
             assert encoder_out.shape == decoder_out.shape
View File
@@ -23,6 +23,8 @@ import torch
 import torch.nn as nn
 from torch import Tensor
+from icefall.utils import is_jit_tracing
 def _ntuple(n):
     def parse(x):
@@ -152,7 +154,7 @@ class BasicNorm(torch.nn.Module):
         self.register_buffer("eps", torch.tensor(eps).log().detach())
     def forward(self, x: Tensor) -> Tensor:
-        if not torch.jit.is_tracing():
+        if not is_jit_tracing():
             assert x.shape[self.channel_dim] == self.num_channels
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
@@ -424,7 +426,7 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting() or is_jit_tracing():
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -473,7 +475,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting() or is_jit_tracing():
            return x * torch.sigmoid(x - 1.0)
         else:
             return DoubleSwishFunction.apply(x)
View File
@@ -358,6 +358,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -388,6 +389,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
                     sp.decode(decode_streams[i].decoding_result()).split(),
                 )
@@ -405,6 +407,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
                 sp.decode(decode_streams[i].decoding_result()).split(),
             )
View File
@@ -422,6 +422,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -434,9 +435,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -610,6 +611,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
View File
@@ -745,6 +745,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -760,9 +761,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -787,6 +788,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -1067,6 +1069,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
View File
@@ -25,6 +25,7 @@ life by converting them to their non-scaled version during inference.
 import copy
 import re
+from typing import List
 import torch
 import torch.nn as nn
@@ -54,7 +55,10 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
         in_features=scaled_linear.in_features,
         out_features=scaled_linear.out_features,
         bias=True,  # otherwise, it throws errors when converting to PNNX format.
-        device=weight.device,
+        # device=weight.device,  # PyTorch versions before v1.9.0 do not have
+        # this argument. Comment it out for now; we will
+        # see whether it raises an error for versions
+        # after v1.9.0.
     )
     linear.weight.data.copy_(weight)
@@ -164,6 +168,24 @@ def scaled_embedding_to_embedding(
     return embedding
+# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule
+# get_submodule was added to nn.Module at v1.9.0
+def get_submodule(model, target):
+    if target == "":
+        return model
+    atoms: List[str] = target.split(".")
+    mod: torch.nn.Module = model
+    for item in atoms:
+        if not hasattr(mod, item):
+            raise AttributeError(
+                mod._get_name() + " has no " "attribute `" + item + "`"
+            )
+        mod = getattr(mod, item)
+        if not isinstance(mod, torch.nn.Module):
+            raise AttributeError("`" + item + "` is not " "an nn.Module")
+    return mod
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given model to their unscaled versions `nn.Linear`, `nn.Conv1d`,
@@ -200,7 +222,7 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     for k, v in d.items():
         if "." in k:
             parent, child = k.rsplit(".", maxsplit=1)
-            setattr(model.get_submodule(parent), child, v)
+            setattr(get_submodule(model, parent), child, v)
         else:
             setattr(model, k, v)
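
Since model.get_submodule only exists on PyTorch >= v1.9.0, the backported helper above is called instead. A small usage sketch, assuming the get_submodule function from the hunk above (the toy model is made up; on newer PyTorch, model.get_submodule("encoder.0") is equivalent):

import torch.nn as nn

model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Linear(4, 4)))

# Walks the dotted path "encoder" -> "0" the same way
# nn.Module.get_submodule does on PyTorch >= 1.9.0.
sub = get_submodule(model, "encoder.0")
assert isinstance(sub, nn.Linear)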
View File
@@ -359,6 +359,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -389,6 +390,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
                     sp.decode(decode_streams[i].decoding_result()).split(),
                 )
@@ -406,6 +408,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
                 sp.decode(decode_streams[i].decoding_result()).split(),
             )
View File
@@ -578,6 +578,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -591,9 +592,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -618,6 +619,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -831,6 +833,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -371,6 +371,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -401,6 +402,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
                     sp.decode(decode_streams[i].decoding_result()).split(),
                 )
@@ -418,6 +420,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
                 sp.decode(decode_streams[i].decoding_result()).split(),
             )
View File
@@ -564,6 +564,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -577,9 +578,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -604,6 +605,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -817,6 +819,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -371,6 +371,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -401,6 +402,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
                     sp.decode(decode_streams[i].decoding_result()).split(),
                 )
@@ -418,6 +420,7 @@ def decode_dataset(
     for i in sorted(finished_streams, reverse=True):
         decode_results.append(
             (
+                decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
                 sp.decode(decode_streams[i].decoding_result()).split(),
             )
View File
@@ -387,6 +387,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -399,9 +400,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -426,6 +427,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -608,6 +610,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -311,6 +311,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -324,9 +325,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -349,6 +350,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -473,6 +475,8 @@ def main():
     model.to(device)
     model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -295,6 +295,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -306,9 +307,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -333,6 +334,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -424,6 +426,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -292,6 +292,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -303,9 +304,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -330,6 +331,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -422,6 +424,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -500,6 +502,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -500,6 +502,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
     test_clean_cuts = librispeech.test_clean_cuts()
View File
@@ -351,6 +351,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -363,9 +364,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -390,6 +391,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -503,6 +505,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     asr_datamodule = AsrDataModule(args)
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
View File
@@ -365,6 +365,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -377,9 +378,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -405,6 +406,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -561,6 +563,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     spgispeech = SPGISpeechAsrDataModule(args)
     dev_cuts = spgispeech.dev_cuts()
View File
@@ -453,6 +453,7 @@ def decode_dataset(
     zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         zh_texts = []
         en_texts = []
         for i in range(len(texts)):
@@ -487,14 +488,14 @@ def decode_dataset(
             # print(hyps_texts)
             hyps, zh_hyps, en_hyps = hyps_texts
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
-            for hyp_words, ref_text in zip(zh_hyps, zh_texts):
-                this_batch_zh.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, zh_hyps, zh_texts):
+                this_batch_zh.append((cut_id, ref_text, hyp_words))
-            for hyp_words, ref_text in zip(en_hyps, en_texts):
-                this_batch_en.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, en_hyps, en_texts):
+                this_batch_en.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
             zh_results[name + "_zh"].extend(this_batch_zh)
@@ -521,6 +522,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -710,6 +712,8 @@ def main():
         c.supervisions[0].text = text_normalize(text)
         return c
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tal_csasr = TAL_CSASRAsrDataModule(args)
     dev_cuts = tal_csasr.valid_cuts()
View File
@@ -350,6 +350,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -362,9 +363,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -389,6 +390,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -498,6 +500,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tedlium = TedLiumAsrDataModule(args)
     dev_cuts = tedlium.dev_cuts()
     test_cuts = tedlium.test_cuts()
View File
@@ -325,6 +325,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -336,9 +337,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[name].extend(this_batch)
@@ -363,6 +364,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -462,6 +464,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     tedlium = TedLiumAsrDataModule(args)
     dev_cuts = tedlium.dev_cuts()
     test_cuts = tedlium.test_cuts()
View File
@@ -311,6 +311,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -324,9 +325,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -349,6 +350,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -468,6 +470,8 @@ def main():
     model.to(device)
     model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     timit = TimitAsrDataModule(args)
     test_set = "TEST"
     test_dl = timit.test_dataloaders()
View File
@@ -310,6 +310,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -323,9 +324,9 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words))
             results[lm_scale].extend(this_batch)
@@ -348,6 +349,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -467,6 +469,8 @@ def main():
     model.to(device)
     model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     timit = TimitAsrDataModule(args)
     test_set = "TEST"
     test_dl = timit.test_dataloaders()
View File
@@ -491,6 +491,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text)) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -504,8 +505,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -530,6 +531,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -678,6 +680,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
     dev = "dev"
View File
@@ -461,6 +461,7 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         texts = [list(str(text)) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps_dict = decode_one_batch(
             params=params,
@@ -473,8 +474,8 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                this_batch.append((ref_text, hyp_words))
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
             results[name].extend(this_batch)
@@ -499,6 +500,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -682,6 +684,8 @@ def main():
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
     dev = "dev"
View File
@@ -28,6 +28,7 @@ class DecodeStream(object):
     def __init__(
         self,
         params: AttributeDict,
+        cut_id: str,
         initial_states: List[torch.Tensor],
         decoding_graph: Optional[k2.Fsa] = None,
         device: torch.device = torch.device("cpu"),
@@ -48,6 +49,7 @@ class DecodeStream(object):
             assert device == decoding_graph.device
         self.params = params
+        self.cut_id = cut_id
         self.LOG_EPS = math.log(1e-10)
         self.states = initial_states
@@ -102,6 +104,10 @@ class DecodeStream(object):
         """Return True if all the features are processed."""
         return self._done
+    @property
+    def id(self) -> str:
+        return self.cut_id
     def set_features(
         self,
         features: torch.Tensor,
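
In the streaming recipes there is no batch dict to pull ids from at result time, so each DecodeStream carries its cut id from construction through to the finished-stream bookkeeping. A stripped-down sketch of just that mechanism (ToyStream is invented for illustration; it is not the real DecodeStream):

class ToyStream:
    def __init__(self, cut_id: str) -> None:
        self.cut_id = cut_id

    @property
    def id(self) -> str:
        # Read-only view, matching the DecodeStream.id property above.
        return self.cut_id

stream = ToyStream("1089-134686-0001")
assert stream.id == "1089-134686-0001"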
View File
@@ -396,6 +396,7 @@ def decode_dataset(
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
             params=params,
+            cut_id=cut.id,
             initial_states=initial_states,
             decoding_graph=decoding_graph,
             device=device,
@@ -423,6 +424,7 @@ def decode_dataset(
             hyp = decode_streams[i].decoding_result()
             decode_results.append(
                 (
+                    decode_streams[i].id,
                     list(decode_streams[i].ground_truth),
                     [lexicon.token_table[idx] for idx in hyp],
                 )
@@ -441,6 +443,7 @@ def decode_dataset(
         hyp = decode_streams[i].decoding_result()
         decode_results.append(
             (
+                decode_streams[i].id,
                 list(decode_streams[i].ground_truth),
                 [lexicon.token_table[idx] for idx in hyp],
             )
View File
@@ -178,6 +178,7 @@ def decode_dataset(
     results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps = decode_one_batch(
             params=params,
@@ -189,9 +190,9 @@ def decode_dataset(
         this_batch = []
         assert len(hyps) == len(texts)
-        for hyp_words, ref_text in zip(hyps, texts):
+        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
         results.extend(this_batch)
@@ -237,6 +238,7 @@ def save_results(
       Return None.
     """
     recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+    results = sorted(results)
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
@@ -303,6 +305,8 @@ def main():
     model.to(device)
     model.eval()
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     yes_no = YesNoAsrDataModule(args)
     test_dl = yes_no.test_dataloaders()
     results = decode_dataset(
View File
@@ -165,6 +165,7 @@ def decode_dataset(
     results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         hyps = decode_one_batch(
             params=params,
@@ -174,9 +175,9 @@ def decode_dataset(
         this_batch = []
         assert len(hyps) == len(texts)
-        for hyp_words, ref_text in zip(hyps, texts):
+        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
         results.extend(this_batch)
@@ -222,6 +223,7 @@ def save_results(
       Return None.
     """
     recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+    results = sorted(results)
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
@@ -291,6 +293,8 @@ def main():
     model.eval()
     model.device = device
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
     yes_no = YesNoAsrDataModule(args)
     test_dl = yes_no.test_dataloaders()
     results = decode_dataset(
View File
@@ -49,6 +49,7 @@ from .utils import (
     get_alignments,
     get_executor,
     get_texts,
+    is_jit_tracing,
     l1_norm,
     l2_norm,
     linf_norm,
View File
@@ -42,6 +42,18 @@ from icefall.checkpoint import average_checkpoints
 Pathlike = Union[str, Path]
+# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
+# Fixed: https://github.com/pytorch/pytorch/pull/49853
+# The fix was included in v1.9.0
+# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
+def is_jit_tracing():
+    if torch.jit.is_scripting():
+        return False
+    elif torch.jit.is_tracing():
+        return True
+    return False
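
A hedged sketch of how the helper above is meant to be used, assuming the is_jit_tracing definition from the hunk just shown: shape asserts run in eager mode (and under torch.jit.script) but are skipped while torch.jit.trace records the graph, which is where torch.jit.is_tracing() misbehaved on the affected PyTorch versions.

import torch
import torch.nn as nn

class Checked(nn.Module):
    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.num_channels = num_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not is_jit_tracing():
            # Executes eagerly and under scripting, skipped during tracing.
            assert x.shape[-1] == self.num_channels
        return x * 2.0

traced = torch.jit.trace(Checked(4), torch.zeros(2, 4))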
 @contextmanager
 def get_executor():
     # We'll either return a process pool or a distributed worker pool.
@@ -321,7 +333,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
 def store_transcripts(
-    filename: Pathlike, texts: Iterable[Tuple[str, str]]
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
 ) -> None:
     """Save predicted results and reference transcripts to a file.
@@ -329,15 +341,15 @@ def store_transcripts(
       filename:
         File to save the results to.
       texts:
-        An iterable of tuples. The first element is the reference transcript
-        while the second element is the predicted result.
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript, and the third element is the predicted result.
     Returns:
       Return None.
     """
     with open(filename, "w") as f:
-        for ref, hyp in texts:
-            print(f"ref={ref}", file=f)
-            print(f"hyp={hyp}", file=f)
+        for cut_id, ref, hyp in texts:
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
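
With the cut id prefixed to both lines, a recogs file is keyed per utterance, which is what makes it sortable and diff-able across decoding runs. A tiny sketch of the resulting on-disk format, using made-up ids and a throwaway file name:

texts = [
    ("utt-1", ["GOOD", "MORNING"], ["GOOD", "MORNING"]),
    ("utt-2", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
]
with open("recogs-demo.txt", "w") as f:
    for cut_id, ref, hyp in texts:
        print(f"{cut_id}:\tref={ref}", file=f)
        print(f"{cut_id}:\thyp={hyp}", file=f)
# recogs-demo.txt now contains:
# utt-1:	ref=['GOOD', 'MORNING']
# utt-1:	hyp=['GOOD', 'MORNING']
# utt-2:	ref=['HELLO', 'WORLD']
# utt-2:	hyp=['HELLO', 'WORD']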
 def write_error_stats(
@@ -372,8 +384,8 @@ def write_error_stats(
         The reference word `SIR` is missing in the predicted
         results (a deletion error).
       results:
-        An iterable of tuples. The first element is the reference transcript
-        while the second element is the predicted result.
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript, and the third element is the predicted result.
       enable_log:
         If True, also print detailed WER to the console.
         Otherwise, it is written only to the given file.
@@ -389,7 +401,7 @@ def write_error_stats(
     words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
     num_corr = 0
     ERR = "*"
-    for ref, hyp in results:
+    for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:
@@ -405,7 +417,7 @@ def write_error_stats(
             else:
                 words[ref_word][0] += 1
                 num_corr += 1
-    ref_len = sum([len(r) for r, _ in results])
+    ref_len = sum([len(r) for _, r, _ in results])
     sub_errs = sum(subs.values())
     ins_errs = sum(ins.values())
     del_errs = sum(dels.values())
@@ -434,7 +446,7 @@ def write_error_stats(
     print("", file=f)
     print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
-    for ref, hyp in results:
+    for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         combine_successive_errors = True
         if combine_successive_errors:
@@ -461,7 +473,8 @@ def write_error_stats(
         ]
         print(
-            " ".join(
+            f"{cut_id}:\t"
+            + " ".join(
                 (
                     ref_word
                     if ref_word == hyp_word