diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index b78c600c3..a185567da 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -367,6 +367,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -379,8 +380,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -405,6 +406,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -528,6 +530,8 @@ def main(): from lhotse import CutSet from lhotse.dataset.webdataset import export_to_webdataset + # we need cut ids to display recognition results. + args.return_cuts = True aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) dev = "dev" diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index c38c4c65f..d860d3fdb 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -374,6 +374,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -389,9 +390,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -419,6 +420,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -537,6 +539,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True aishell = AishellAsrDataModule(args) test_cuts = aishell.test_cuts() test_dl = aishell.test_dataloaders(test_cuts) diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 35a7d98fc..f33ddd48b 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -386,6 +386,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -401,9 +402,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -431,6 +432,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -556,6 +558,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True aishell = AishellAsrDataModule(args) test_cuts = aishell.test_cuts() test_dl = aishell.test_dataloaders(test_cuts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 6aea306c8..a6a01c2c6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -377,6 +377,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -389,9 +390,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -416,6 +417,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -606,6 +608,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True asr_datamodule = AsrDataModule(args) aishell = AIShell(manifest_dir=args.manifest_dir) test_cuts = aishell.test_cuts() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index f3c8e8f44..58a999c22 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -241,6 +241,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -253,9 +254,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -278,6 +279,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -365,6 +367,8 @@ def main(): model.to(device) model.eval() + # we need cut ids to display recognition results. + args.return_cuts = True aishell = AishellAsrDataModule(args) test_cuts = aishell.test_cuts() test_dl = aishell.test_dataloaders(test_cuts) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index a7b030fa5..7a96b6f73 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -38,8 +38,8 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, - write_error_stats, str2bool, + write_error_stats, ) @@ -296,6 +296,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -307,9 +308,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -334,6 +335,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned @@ -438,6 +440,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True aishell = AishellAsrDataModule(args) test_cuts = aishell.test_cuts() test_dl = aishell.test_dataloaders(test_cuts) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 47265f846..7eb273da0 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -341,6 +341,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -353,9 +354,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -380,6 +381,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -496,6 +498,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True asr_datamodule = AsrDataModule(args) aishell = AIShell(manifest_dir=args.manifest_dir) test_cuts = aishell.test_cuts() diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 4773ebc7d..1fe39fed2 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -345,6 +345,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -357,9 +358,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -384,6 +385,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -498,6 +500,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True aishell = AishellAsrDataModule(args) test_cuts = aishell.test_cuts() test_dl = aishell.test_dataloaders(test_cuts) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index f03bd34d3..7d6f6f6d5 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -514,6 +514,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -527,8 +528,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -553,6 +554,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -756,6 +758,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True aishell2 = AiShell2AsrDataModule(args) valid_cuts = aishell2.valid_cuts() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index d329410e1..739dfc4a1 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -378,6 +378,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -390,8 +391,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -416,6 +417,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -607,6 +609,8 @@ def main(): c.supervisions[0].text = text_normalize(text) return c + # we need cut ids to display recognition results. 
+ args.return_cuts = True aishell4 = Aishell4AsrDataModule(args) test_cuts = aishell4.test_cuts() test_cuts = test_cuts.map(text_normalize_for_cut) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index cb455838e..65fc74728 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -367,6 +367,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -379,8 +380,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -405,6 +406,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -535,6 +537,8 @@ def main(): from lhotse import CutSet from lhotse.dataset.webdataset import export_to_webdataset + # we need cut ids to display recognition results. + args.return_cuts = True alimeeting = AlimeetingAsrDataModule(args) dev = "eval" diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 6ab9852b4..f4b438aad 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -451,6 +451,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -469,9 +470,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) else: @@ -512,6 +513,7 @@ def save_results( for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" results = post_processing(results) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -676,6 +678,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True gigaspeech = GigaSpeechAsrDataModule(args) dev_cuts = gigaspeech.dev_cuts() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index ce5116336..d7ecc3fdc 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -374,6 +374,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -386,9 +387,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -414,6 +415,7 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = post_processing(results) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -544,6 +546,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True gigaspeech = GigaSpeechAsrDataModule(args) dev_cuts = gigaspeech.dev_cuts() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 0e8247b8d..7a8fb2130 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -525,6 +525,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -544,9 +545,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) else: @@ -586,6 +587,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -779,6 +781,8 @@ def main(): ) rnn_lm_model.eval() + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 8a4cad1ad..6b9da12a9 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -31,14 +31,13 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.decode import ( get_lattice, nbest_decoding, @@ -633,6 +632,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -652,9 +652,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) else: @@ -694,6 +694,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -956,6 +957,8 @@ def main(): ) rnn_lm_model.eval() + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index a77168b62..23372034a 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -449,6 +449,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -466,9 +467,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -496,6 +497,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: logging.info(f"The transcripts are stored in {recog_path}") @@ -661,6 +663,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) # CAUTION: `test_sets` is for displaying only. 
# If you want to skip test-clean, you have to skip diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 287fb94df..a03fe2684 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -403,6 +403,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -415,9 +416,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -442,6 +443,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -624,6 +626,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 69ee7ee9a..9494e1fc1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -29,6 +29,7 @@ class Stream(object): def __init__( self, params: AttributeDict, + cut_id: str, decoding_graph: Optional[k2.Fsa] = None, device: torch.device = torch.device("cpu"), LOG_EPS: float = math.log(1e-10), @@ -44,6 +45,7 @@ class Stream(object): The device to run this stream. 
""" self.LOG_EPS = LOG_EPS + self.cut_id = cut_id # Containing attention caches and convolution caches self.states: Optional[ @@ -138,6 +140,10 @@ class Stream(object): """Return True if all feature frames are processed.""" return self._done + @property + def id(self) -> str: + return self.cut_id + def decoding_result(self) -> List[int]: """Obtain current decoding result.""" if self.decoding_method == "greedy_search": diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 0a6bbfa8b..61dbe8658 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -74,7 +74,6 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 -from lhotse import CutSet import numpy as np import sentencepiece as spm import torch @@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from beam_search import Hypothesis, HypothesisList, get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet from stream import Stream from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -678,6 +678,7 @@ def decode_dataset( # Each utterance has a Stream. stream = Stream( params=params, + cut_id=cut.id, decoding_graph=decoding_graph, device=device, LOG_EPS=LOG_EPSILON, @@ -711,6 +712,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + streams[i].id, streams[i].ground_truth.split(), sp.decode(streams[i].decoding_result()).split(), ) @@ -731,6 +733,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + streams[i].id, streams[i].ground_truth.split(), sp.decode(streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 402ec4293..d204a9d75 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -403,6 +403,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -415,9 +416,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -442,6 +443,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -624,6 +626,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 0f687898f..71150392d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -74,7 +74,6 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 -from lhotse import CutSet import numpy as np import sentencepiece as spm import torch @@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from beam_search import Hypothesis, HypothesisList, get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet from stream import Stream from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model @@ -678,6 +678,7 @@ def decode_dataset( # Each utterance has a Stream. stream = Stream( params=params, + cut_id=cut.id, decoding_graph=decoding_graph, device=device, LOG_EPS=LOG_EPSILON, @@ -711,6 +712,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + streams[i].id, streams[i].ground_truth.split(), sp.decode(streams[i].decoding_result()).split(), ) @@ -731,6 +733,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + streams[i].id, streams[i].ground_truth.split(), sp.decode(streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index e9989579b..282ce3737 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -391,6 +391,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -403,9 +404,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -430,6 +431,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -612,6 +614,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index b7558089c..ab6cf336c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -551,6 +551,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -564,9 +565,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -591,6 +592,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -631,6 +633,8 @@ def main(): LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + # we need cut ids to display recognition results. + args.return_cuts = True params = get_params() params.update(vars(args)) @@ -754,6 +758,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 6c0e9ba19..386248554 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -28,6 +28,7 @@ class DecodeStream(object): def __init__( self, params: AttributeDict, + cut_id: str, initial_states: List[torch.Tensor], decoding_graph: Optional[k2.Fsa] = None, device: torch.device = torch.device("cpu"), @@ -48,6 +49,7 @@ class DecodeStream(object): assert device == decoding_graph.device self.params = params + self.cut_id = cut_id self.LOG_EPS = math.log(1e-10) self.states = initial_states @@ -102,6 +104,10 @@ class DecodeStream(object): """Return True if all the features are processed.""" return self._done + @property + def id(self) -> str: + return self.cut_id + def set_features( self, features: torch.Tensor, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index e455627f3..d2cae4f9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -356,6 +356,7 @@ def decode_dataset( # each utterance has a DecodeStream. 
decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -385,6 +386,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) @@ -402,6 +404,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index e95360d1d..9a0405c57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -32,7 +32,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(EncoderInterface): @@ -155,7 +155,7 @@ class Conformer(EncoderInterface): # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) @@ -788,7 +788,7 @@ class RelPositionalEncoding(torch.nn.Module): ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() - if torch.jit.is_tracing(): + if is_jit_tracing(): # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., # It assumes that the maximum input won't have more than # 10k frames. 
@@ -1015,12 +1015,12 @@ class RelPositionMultiheadAttention(nn.Module): (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert ( n == left_context + 2 * time1 - 1 ), f"{n} == {left_context} + 2 * {time1} - 1" - if torch.jit.is_tracing(): + if is_jit_tracing(): rows = torch.arange(start=time1 - 1, end=-1, step=-1) cols = torch.arange(time2) rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) @@ -1111,12 +1111,12 @@ class RelPositionMultiheadAttention(nn.Module): """ tgt_len, bsz, embed_dim = query.size() - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert embed_dim == embed_dim_to_check assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" @@ -1232,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module): src_len = k.size(0) - if key_padding_mask is not None and not torch.jit.is_tracing(): + if key_padding_mask is not None and not is_jit_tracing(): assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz ) @@ -1243,7 +1243,7 @@ class RelPositionMultiheadAttention(nn.Module): q = q.transpose(0, 1) # (batch, time1, head, d_k) pos_emb_bsz = pos_emb.size(0) - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) @@ -1280,7 +1280,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz * num_heads, tgt_len, -1 ) - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ bsz * num_heads, tgt_len, @@ -1345,7 +1345,7 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert list(attn_output.size()) == [ bsz * num_heads, tgt_len, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 60a948a99..34fd31e7e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -574,6 +574,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -587,9 +588,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -614,6 +615,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -777,6 +779,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index bd0df5d49..e01167285 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union - import torch import torch.nn as nn import torch.nn.functional as F from scaling import ScaledConv1d, ScaledEmbedding +from icefall.utils import is_jit_tracing + class Decoder(nn.Module): """This class modifies the stateless decoder from the following paper: @@ -80,7 +80,10 @@ class Decoder(nn.Module): self.conv = nn.Identity() def forward( - self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True + self, + y: torch.Tensor, + need_pad: bool = True # Annotation should be Union[bool, torch.Tensor] + # but, torch.jit.script does not support Union. ) -> torch.Tensor: """ Args: @@ -108,7 +111,7 @@ class Decoder(nn.Module): else: # During inference time, there is no need to do extra padding # as we only need one output - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index b2d6ed0f2..6a9d08033 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -18,6 +18,8 @@ import torch import torch.nn as nn from scaling import ScaledLinear +from icefall.utils import is_jit_tracing + class Joiner(nn.Module): def __init__( @@ -52,7 +54,7 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). """ - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert encoder_out.ndim == decoder_out.ndim assert encoder_out.ndim in (2, 4) assert encoder_out.shape == decoder_out.shape diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 566b3622f..e93d6e4a3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -23,6 +23,8 @@ import torch import torch.nn as nn from torch import Tensor +from icefall.utils import is_jit_tracing + def _ntuple(n): def parse(x): @@ -152,7 +154,7 @@ class BasicNorm(torch.nn.Module): self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: - if not torch.jit.is_tracing(): + if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) @@ -424,7 +426,7 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): + if torch.jit.is_scripting() or is_jit_tracing(): return x else: return ActivationBalancerFunction.apply( @@ -473,7 +475,7 @@ class DoubleSwish(torch.nn.Module): """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). 
""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): + if torch.jit.is_scripting() or is_jit_tracing(): return x * torch.sigmoid(x - 1.0) else: return DoubleSwishFunction.apply(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 79963c968..d76a03946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -358,6 +358,7 @@ def decode_dataset( # each utterance has a DecodeStream. decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -388,6 +389,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) @@ -405,6 +407,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 8d6e33e9d..5784a78ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -422,6 +422,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -434,9 +435,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -610,6 +611,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True asr_datamodule = AsrDataModule(args) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d2605c072..72d6f656c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -745,6 +745,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -760,9 +761,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -787,6 +788,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -1067,6 +1069,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True asr_datamodule = AsrDataModule(args) librispeech = LibriSpeech(manifest_dir=args.manifest_dir) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 79b178421..992b71dd1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -25,6 +25,7 @@ life by converting them to their non-scaled version during inference. import copy import re +from typing import List import torch import torch.nn as nn @@ -54,7 +55,10 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: in_features=scaled_linear.in_features, out_features=scaled_linear.out_features, bias=True, # otherwise, it throws errors when converting to PNNX format. - device=weight.device, + # device=weight.device, # PyTorch versions before v1.9.0 do not have + # this argument.
Commented out for now; we will + # see if it raises an error for versions + # after v1.9.0 ) linear.weight.data.copy_(weight) @@ -164,6 +168,24 @@ def scaled_embedding_to_embedding( return embedding +# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule +# get_submodule was added to nn.Module at v1.9.0 +def get_submodule(model, target): + if target == "": + return model + atoms: List[str] = target.split(".") + mod: torch.nn.Module = model + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " "an nn.Module") + return mod + + def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, @@ -200,7 +222,7 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): for k, v in d.items(): if "." in k: parent, child = k.rsplit(".", maxsplit=1) - setattr(model.get_submodule(parent), child, v) + setattr(get_submodule(model, parent), child, v) else: setattr(model, k, v) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 1976d19a6..10bb44e00 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -359,6 +359,7 @@ def decode_dataset( # each utterance has a DecodeStream. decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -389,6 +390,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) @@ -406,6 +408,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index d8ae8e026..8431492e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -578,6 +578,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -591,9 +592,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -618,6 +619,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -831,6 +833,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index de89d41c2..7af9ea9b8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -371,6 +371,7 @@ def decode_dataset( # each utterance has a DecodeStream. decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -401,6 +402,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) @@ -418,6 +420,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 2d0965023..32bbd16f7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -564,6 +564,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -577,9 +578,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -604,6 +605,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -817,6 +819,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index d47d57d1b..6fee9483e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -371,6 +371,7 @@ def decode_dataset( # each utterance has a DecodeStream. 
decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -401,6 +402,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) @@ -418,6 +420,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): decode_results.append( ( + decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 701cad73c..2f69ba401 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -387,6 +387,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -399,9 +400,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -426,6 +427,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -608,6 +610,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 9c964c2aa..f1aacb5e7 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -311,6 +311,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -324,9 +325,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -349,6 +350,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -473,6 +475,8 @@ def main(): model.to(device) model.eval() + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 990513ed9..83e924256 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -295,6 +295,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -306,9 +307,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -333,6 +334,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -424,6 +426,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 18ae5234c..43debe643 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -292,6 +292,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -303,9 +304,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -330,6 +331,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -422,6 +424,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 5ea17b173..19a685090 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -350,6 +350,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -362,9 +363,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -389,6 +390,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -500,6 +502,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index 4cf1e559c..f48ce82f4 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -350,6 +350,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -362,9 +363,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -389,6 +390,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -500,6 +502,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 955366970..2bb6df5d6 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -351,6 +351,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -363,9 +364,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -390,6 +391,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -503,6 +505,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True asr_datamodule = AsrDataModule(args) librispeech = LibriSpeech(manifest_dir=args.manifest_dir) diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index ae49d166b..c13b980c6 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -365,6 +365,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -377,9 +378,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -405,6 +406,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -561,6 +563,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True spgispeech = SPGISpeechAsrDataModule(args) dev_cuts = spgispeech.dev_cuts() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index 305729a99..45f702163 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -453,6 +453,7 @@ def decode_dataset( zh_char = "[\u4e00-\u9fa5]+" # Chinese chars for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] zh_texts = [] en_texts = [] for i in range(len(texts)): @@ -487,14 +488,14 @@ def decode_dataset( # print(hyps_texts) hyps, zh_hyps, en_hyps = hyps_texts assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) - for hyp_words, ref_text in zip(zh_hyps, zh_texts): - this_batch_zh.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, zh_hyps, zh_texts): + this_batch_zh.append((cut_id, ref_text, hyp_words)) - for hyp_words, ref_text in zip(en_hyps, en_texts): - this_batch_en.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, en_hyps, en_texts): + this_batch_en.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) zh_results[name + "_zh"].extend(this_batch_zh) @@ -521,6 +522,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -710,6 +712,8 @@ def main(): c.supervisions[0].text = text_normalize(text) return c + # we need cut ids to display recognition results. + args.return_cuts = True tal_csasr = TAL_CSASRAsrDataModule(args) dev_cuts = tal_csasr.valid_cuts() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 4d9d3c3cf..bb352baa7 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -350,6 +350,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -362,9 +363,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -389,6 +390,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -498,6 +500,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
+ args.return_cuts = True tedlium = TedLiumAsrDataModule(args) dev_cuts = tedlium.dev_cuts() test_cuts = tedlium.test_cuts() diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index 3185e7581..c1aa2c366 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -325,6 +325,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -336,9 +337,9 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -363,6 +364,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -462,6 +464,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. + args.return_cuts = True tedlium = TedLiumAsrDataModule(args) dev_cuts = tedlium.dev_cuts() test_cuts = tedlium.test_cuts() diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index b141e58fa..84d9f7f1b 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -311,6 +311,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -324,9 +325,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -349,6 +350,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -468,6 +470,8 @@ def main(): model.to(device) model.eval() + # we need cut ids to display recognition results. 
+ args.return_cuts = True timit = TimitAsrDataModule(args) test_set = "TEST" test_dl = timit.test_dataloaders() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index e9ca96615..7672a2e1d 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -310,6 +310,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -323,9 +324,9 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch) @@ -348,6 +349,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -467,6 +469,8 @@ def main(): model.to(device) model.eval() + # we need cut ids to display recognition results. + args.return_cuts = True timit = TimitAsrDataModule(args) test_set = "TEST" test_dl = timit.test_dataloaders() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 7c06cdb3d..bbd8680b2 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -491,6 +491,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] texts = [list(str(text)) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -504,8 +505,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -530,6 +531,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -678,6 +680,8 @@ def main(): from lhotse import CutSet from lhotse.dataset.webdataset import export_to_webdataset + # we need cut ids to display recognition results. 
+ args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) dev = "dev" diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index ca997456f..c36df2458 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -461,6 +461,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] texts = [list(str(text)) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -473,8 +474,8 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - this_batch.append((ref_text, hyp_words)) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) @@ -499,6 +500,7 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -682,6 +684,8 @@ def main(): from lhotse import CutSet from lhotse.dataset.webdataset import export_to_webdataset + # we need cut ids to display recognition results. + args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) dev = "dev" diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py index 6c0e9ba19..386248554 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py @@ -28,6 +28,7 @@ class DecodeStream(object): def __init__( self, params: AttributeDict, + cut_id: str, initial_states: List[torch.Tensor], decoding_graph: Optional[k2.Fsa] = None, device: torch.device = torch.device("cpu"), @@ -48,6 +49,7 @@ class DecodeStream(object): assert device == decoding_graph.device self.params = params + self.cut_id = cut_id self.LOG_EPS = math.log(1e-10) self.states = initial_states @@ -102,6 +104,10 @@ class DecodeStream(object): """Return True if all the features are processed.""" return self._done + @property + def id(self) -> str: + return self.cut_id + def set_features( self, features: torch.Tensor, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 2a383ca46..ff96c6487 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -396,6 +396,7 @@ def decode_dataset( # each utterance has a DecodeStream. 
decode_stream = DecodeStream( params=params, + cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, @@ -423,6 +424,7 @@ def decode_dataset( hyp = decode_streams[i].decoding_result() decode_results.append( ( + decode_streams[i].id, list(decode_streams[i].ground_truth), [lexicon.token_table[idx] for idx in hyp], ) @@ -441,6 +443,7 @@ def decode_dataset( hyp = decode_streams[i].decoding_result() decode_results.append( ( + decode_streams[i].id, list(decode_streams[i].ground_truth), [lexicon.token_table[idx] for idx in hyp], ) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index a6a57a2fc..79adcb14e 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -178,6 +178,7 @@ def decode_dataset( results = [] for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps = decode_one_batch( params=params, @@ -189,9 +190,9 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results.extend(this_batch) @@ -237,6 +238,7 @@ def save_results( Return None. """ recog_path = exp_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -303,6 +305,8 @@ def main(): model.to(device) model.eval() + # we need cut ids to display recognition results. + args.return_cuts = True yes_no = YesNoAsrDataModule(args) test_dl = yes_no.test_dataloaders() results = decode_dataset( diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py index abb34da4c..6714180db 100755 --- a/egs/yesno/ASR/transducer/decode.py +++ b/egs/yesno/ASR/transducer/decode.py @@ -165,6 +165,7 @@ def decode_dataset( results = [] for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps = decode_one_batch( params=params, @@ -174,9 +175,9 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) results.extend(this_batch) @@ -222,6 +223,7 @@ def save_results( Return None. """ recog_path = exp_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -291,6 +293,8 @@ def main(): model.eval() model.device = device + # we need cut ids to display recognition results. 
+ args.return_cuts = True yes_no = YesNoAsrDataModule(args) test_dl = yes_no.test_dataloaders() results = decode_dataset( diff --git a/icefall/__init__.py b/icefall/__init__.py index 52d551c6a..0399c8459 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -49,6 +49,7 @@ from .utils import ( get_alignments, get_executor, get_texts, + is_jit_tracing, l1_norm, l2_norm, linf_norm, diff --git a/icefall/utils.py b/icefall/utils.py index f40f769f8..2b089c8d0 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -42,6 +42,18 @@ from icefall.checkpoint import average_checkpoints Pathlike = Union[str, Path] +# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379 +# Fixed: https://github.com/pytorch/pytorch/pull/49853 +# The fix was included in v1.9.0 +# https://github.com/pytorch/pytorch/releases/tag/v1.9.0 +def is_jit_tracing(): + if torch.jit.is_scripting(): + return False + elif torch.jit.is_tracing(): + return True + return False + + @contextmanager def get_executor(): # We'll either return a process pool or a distributed worker pool. @@ -321,7 +333,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: def store_transcripts( - filename: Pathlike, texts: Iterable[Tuple[str, str]] + filename: Pathlike, texts: Iterable[Tuple[str, str, str]] ) -> None: """Save predicted results and reference transcripts to a file. @@ -329,15 +341,15 @@ def store_transcripts( filename: File to save the results to. texts: - An iterable of tuples. The first element is the reference transcript - while the second element is the predicted result. + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. Returns: Return None. """ with open(filename, "w") as f: - for ref, hyp in texts: - print(f"ref={ref}", file=f) - print(f"hyp={hyp}", file=f) + for cut_id, ref, hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) def write_error_stats( @@ -372,8 +384,8 @@ def write_error_stats( The reference word `SIR` is missing in the predicted results (a deletion error). results: - An iterable of tuples. The first element is the reference transcript - while the second element is the predicted result. + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. @@ -389,7 +401,7 @@ def write_error_stats( words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" - for ref, hyp in results: + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR) for ref_word, hyp_word in ali: if ref_word == ERR: @@ -405,7 +417,7 @@ def write_error_stats( else: words[ref_word][0] += 1 num_corr += 1 - ref_len = sum([len(r) for r, _ in results]) + ref_len = sum([len(r) for _, r, _ in results]) sub_errs = sum(subs.values()) ins_errs = sum(ins.values()) del_errs = sum(dels.values()) @@ -434,7 +446,7 @@ def write_error_stats( print("", file=f) print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) - for ref, hyp in results: + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR) combine_successive_errors = True if combine_successive_errors: @@ -461,7 +473,8 @@ def write_error_stats( ] print( - " ".join( + f"{cut_id}:\t" + + " ".join( ( ref_word if ref_word == hyp_word
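
Taken together, the changes above apply one pattern to every decoding script: the data module is created with `args.return_cuts = True` so that the cuts (and hence their ids) show up in `batch["supervisions"]["cut"]`, each entry in `results` becomes a `(cut_id, ref, hyp)` triple, `results` is sorted by cut id before writing, and `store_transcripts` prefixes every ref/hyp line with the id. Below is a minimal, self-contained sketch of that flow using made-up cut ids and transcripts; the output file name is only an example of the `recogs-{test_set_name}-{key}-{params.suffix}.txt` pattern used in the real scripts.

# Illustrative sketch only: mirrors the (cut_id, ref, hyp) convention and the
# updated store_transcripts() output layout; all data here is made up.
results = [
    ("1089-134686-0002", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
    ("1089-134686-0001", ["GOOD", "MORNING"], ["GOOD", "MORNING"]),
]

# Sorting by cut id makes the recogs-*.txt files deterministic across runs.
results = sorted(results)

with open("recogs-test-clean-greedy_search.txt", "w") as f:
    for cut_id, ref, hyp in results:
        # Same layout as the new store_transcripts(): the cut id prefixes
        # both the reference and the hypothesis line.
        print(f"{cut_id}:\tref={ref}", file=f)
        print(f"{cut_id}:\thyp={hyp}", file=f)

With the cut id on every line, a specific utterance can be looked up in the recogs file and matched against the per-utterance details that write_error_stats prints.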