From 4a6671240630958bdadfaeb72af640b7cdb7ea0a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 25 Jul 2021 18:21:26 +0800 Subject: [PATCH] Add LM rescoring. --- .github/workflows/test.yml | 8 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 291 ++++++++++++++++---- icefall/decode.py | 284 +++++++++++++++++++ icefall/utils.py | 55 +++- 4 files changed, 583 insertions(+), 55 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5af8a9ee6..9a298877a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.2.dev20210723"] + k2-version: ["1.2.dev20210724"] fail-fast: false steps: @@ -64,9 +64,7 @@ jobs: ls -lh export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH echo $PYTHONPATH - # Skip CtcTrainingGraphCompiler since it requires - # k2.ctc_topo, which has not been merged into master - pytest -k "not TestCtc" ./test + pytest ./test - name: Run tests if: startsWith(matrix.os, 'macos') @@ -76,4 +74,4 @@ jobs: lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") echo "lib_path: $lib_path" export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH - pytest -k "not TestCtc" ./test + pytest ./test diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 885ebb1fd..2a6dc671e 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -3,8 +3,9 @@ import argparse import logging +from collections import defaultdict from pathlib import Path -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import k2 import torch @@ -13,12 +14,19 @@ from model import TdnnLstm from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.dataset.librispeech import LibriSpeechAsrDataModule -from icefall.decode import get_lattice, nbest_decoding, one_best_decoding +from icefall.decode import ( + get_lattice, + nbest_decoding, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, get_texts, setup_logger, + store_transcripts, write_error_stats, ) @@ -51,36 +59,73 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn_lstm_ctc/exp/"), "lang_dir": Path("data/lang"), + "lm_dir": Path("data/lm"), "feature_dim": 80, "subsampling_factor": 3, "search_beam": 20, - "output_beam": 8, + "output_beam": 5, "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, - "method": "1best", # [1best, nbest] - "num_paths": 30, # used when method is nbest + # Possible values for method: + # - 1best + # - nbest + # - nbest-rescoring + # - whole-lattice-rescoring + "method": "whole-lattice-rescoring", + # num_paths is used when method is "nbest" and "nbest-rescoring" + "num_paths": 30, } ) return params -@torch.no_grad() def decode_one_batch( params: AttributeDict, model: nn.Module, HLG: k2.Fsa, batch: dict, lexicon: Lexicon, -) -> List[Tuple[List[str], List[str]]]: - """Decode one batch and return a list of tuples containing - `(ref_words, hyp_words)`. + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. 
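+          If the method is "nbest", the number of paths used is also
+          appended to the key, e.g., `no_rescore-30` when params.num_paths
+          is 30.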
+ If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. Args: params: - It is the return value of :func:`get_params`. + It's the return value of :func:`get_params`. + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + model: + The neural model. + HLG: + The decoding graph. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. """ device = HLG.device feature = batch["inputs"] @@ -114,29 +159,154 @@ def decode_one_batch( max_active_states=params.max_active_states, ) - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + ) + key = f"no_rescore-{params.num_paths}" + hyps = get_texts(best_path) + hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] + + lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, ) else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list ) - hyps = get_texts(best_path) - hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + ans = dict() + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + return ans - texts = supervisions["text"] +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + lexicon: Lexicon, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[int], List[int]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". 
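+        When params.method is "whole-lattice-rescoring", it must also
+        contain epsilon self-loops.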
In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ results = [] - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - results.append((ref_words, hyp_words)) + + num_cuts = 0 + tot_num_cuts = len(dl.dataset.cuts) + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + batch=batch, + lexicon=lexicon, + G=G, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + logging.info( + f"batch {batch_idx}, cuts processed until now is " + f"{num_cuts}/{tot_num_cuts} " + f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + ) return results +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) @@ -160,6 +330,45 @@ def main(): HLG = HLG.to(device) assert HLG.requires_grad is False + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.words["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. 
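+                # (Any label whose ID is >= the ID of `#0` is not a
+                # regular word, e.g. the back-off symbol `#0` itself, so
+                # it is mapped to 0, i.e. epsilon.)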
+ G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt") + G = k2.Fsa.from_dict(d).to(device) + + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + model = TdnnLstm( num_features=params.feature_dim, num_classes=max_phone_id + 1, # +1 for the blank symbol @@ -188,32 +397,18 @@ def main(): # test_sets = ["test-clean", "test-other"] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): - tot_num_cuts = len(test_dl.dataset.cuts) - num_cuts = 0 + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + lexicon=lexicon, + G=G, + ) - results = [] - for batch_idx, batch in enumerate(test_dl): - this_batch = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - ) - results.extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" - ) - - errs_filename = params.exp_dir / f"errs-{test_set}.txt" - with open(errs_filename, "w") as f: - write_error_stats(f, test_set, results) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/icefall/decode.py b/icefall/decode.py index ed663bce8..bb8d0c10e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,3 +1,6 @@ +import logging +from typing import Dict, List + import k2 import torch @@ -243,3 +246,284 @@ def nbest_decoding( best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels return best_path_fsa + + +def compute_am_scores( + lattice: k2.Fsa, + word_fsa_with_epsilon_loops: k2.Fsa, + path_to_seq_map: torch.Tensor, +) -> torch.Tensor: + """Compute AM scores of n-best lists (represented as word_fsas). + + Args: + lattice: + An FsaVec, e.g., the return value of :func:`get_lattice` + It must have the attribute `lm_scores`. + word_fsa_with_epsilon_loops: + An FsaVec representing an n-best list. Note that it has been processed + by `k2.add_epsilon_self_loops`. + path_to_seq_map: + A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates + which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to. + path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0(). + Returns: + Return a 1-D torch.Tensor containing the AM scores of each path. + `ans.numel() == word_fsas_with_epsilon_loops.shape[0]` + """ + assert len(lattice.shape) == 3 + assert hasattr(lattice, "lm_scores") + + # k2.compose() currently does not support b_to_a_map. To void + # replicating `lats`, we use k2.intersect_device here. + # + # lattice has token IDs as `labels` and word IDs as aux_labels, so we + # need to invert it here. 
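+    # (k2.invert() returns a new FsaVec with `labels` and `aux_labels`
+    # swapped.)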
+ inv_lattice = k2.invert(lattice) + + # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor) + # and its `aux_labels` are token IDs ( a k2.RaggedInt with 2 axes) + + # Remove its `aux_labels` since it is not needed in the + # following computation + del inv_lattice.aux_labels + inv_lattice = k2.arc_sort(inv_lattice) + + am_path_lattice = _intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True, + ) + + am_path_lattice = k2.top_sort(k2.connect(am_path_lattice)) + + # The `scores` of every arc consists of `am_scores` and `lm_scores` + am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores + + am_scores = am_path_lattice.get_tot_scores( + use_double_scores=True, log_semiring=False + ) + + return am_scores + + +def rescore_with_n_best_list( + lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float] +) -> Dict[str, k2.Fsa]: + """Decode using n-best list with LM rescoring. + + `lattice` is a decoding lattice with 3 axes. This function first + extracts `num_paths` paths from `lattice` for each sequence using + `k2.random_paths`. The `am_scores` of these paths are computed. + For each path, its `lm_scores` is computed using `G` (which is an LM). + The final `tot_scores` is the sum of `am_scores` and `lm_scores`. + The path with the largest `tot_scores` within a sequence is used + as the decoding output. + + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + G: + An FsaVec representing the language model (LM). Note that it + is an FsaVec, but it contains only one Fsa. + num_paths: + It is the size `n` in `n-best` list. + lm_scale_list: + A list containing lm_scale values. + Returns: + A dict of FsaVec, whose key is an lm_scale and the value is the + best decoding path for each sequence in the lattice. + """ + device = lattice.device + + assert len(lattice.shape) == 3 + assert hasattr(lattice, "aux_labels") + assert hasattr(lattice, "lm_scores") + + assert G.shape == (1, None, None) + assert G.device == device + assert hasattr(G, "aux_labels") is False + + # First, extract `num_paths` paths for each sequence. + # path is a k2.RaggedInt with axes [seq][path][arc_pos] + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + + # word_seq is a k2.RaggedInt sharing the same shape as `path` + # but it contains word IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + word_seq = k2.index(lattice.aux_labels, path) + + # Remove epsilons and -1 from word_seq + word_seq = k2.ragged.remove_values_leq(word_seq, 0) + + # Remove paths that has identical word sequences. + # + # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # except that there are no repeated paths with the same word_seq + # within a sequence. + # + # num_repeats is also a k2.RaggedInt with 2 axes containing the + # multiplicities of each path. + # num_repeats.num_elements() == unique_word_seqs.num_elements() + # + # Since k2.ragged.unique_sequences will reorder paths within a seq, + # `new2old` is a 1-D torch.Tensor mapping from the output path index + # to the input path index. + # new2old.numel() == unique_word_seqs.tot_size(1) + unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( + word_seq, need_num_repeats=True, need_new2old_indexes=True + ) + + seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + + # path_to_seq_map is a 1-D torch.Tensor. 
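+    # Its number of elements equals the total number of unique paths
+    # over all sequences.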
+ # path_to_seq_map[i] is the seq to which the i-th path + # belongs. + path_to_seq_map = seq_to_path_shape.row_ids(1) + + # Remove the seq axis. + # Now unique_word_seq has only two axes [path][word] + unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + + # word_fsa is an FsaVec with axes [path][state][arc] + word_fsa = k2.linear_fsa(unique_word_seq) + + word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) + + am_scores = compute_am_scores( + lattice, word_fsa_with_epsilon_loops, path_to_seq_map + ) + + # Now compute lm_scores + b_to_a_map = torch.zeros_like(path_to_seq_map) + lm_path_lattice = _intersect_device( + G, + word_fsa_with_epsilon_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ) + lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice)) + lm_scores = lm_path_lattice.get_tot_scores( + use_double_scores=True, log_semiring=False + ) + + path_2axes = k2.ragged.remove_axis(path, 0) + + ans = dict() + for lm_scale in lm_scale_list: + tot_scores = am_scores / lm_scale + lm_scores + + # Remember that we used `k2.ragged.unique_sequences` to remove repeated + # paths to avoid redundant computation in `k2.intersect_device`. + # Now we use `num_repeats` to correct the scores for each path. + # + # NOTE(fangjun): It is commented out as it leads to a worse WER + # tot_scores = tot_scores * num_repeats.values() + + ragged_tot_scores = k2.RaggedFloat( + seq_to_path_shape, tot_scores.to(torch.float32) + ) + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + # Use k2.index here since argmax_indexes' dtype is torch.int32 + best_path_indexes = k2.index(new2old, argmax_indexes) + + # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] + best_path = k2.index(path_2axes, best_path_indexes) + + # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # Note that it contains -1s. + labels = k2.index(lattice.labels.contiguous(), best_path) + + labels = k2.ragged.remove_values_eq(labels, -1) + + # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so + # aux_labels is also a k2.RaggedInt with 2 axes + aux_labels = k2.index(lattice.aux_labels, best_path.values()) + + best_path_fsa = k2.linear_fsa(labels) + best_path_fsa.aux_labels = aux_labels + + key = f"lm_scale_{lm_scale}" + ans[key] = best_path_fsa + + return ans + + +def rescore_with_whole_lattice( + lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float] +) -> Dict[str, k2.Fsa]: + """Use whole lattice to rescore. + + Args: + lattice: + An FsaVec It can be the return value of :func:`get_lattice`. + G_with_epsilon_loops: + An FsaVec representing the language model (LM). Note that it + is an FsaVec, but it contains only one Fsa. + lm_scale_list: + A list containing lm_scale values. + Returns: + A dict of FsaVec, whose key is a lm_scale and the value represents the + best decoding path for each sequence in the lattice. + """ + assert len(lattice.shape) == 3 + assert hasattr(lattice, "lm_scores") + assert G_with_epsilon_loops.shape == (1, None, None) + + device = lattice.device + lattice.scores = lattice.scores - lattice.lm_scores + # We will use lm_scores from G, so remove lats.lm_scores here + del lattice.lm_scores + assert hasattr(lattice, "lm_scores") is False + + # Now, lattice.scores contains only am_scores + + # inv_lattice has word IDs as labels. 
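+    # (The labels of G are word IDs as well, which is what makes the
+    # intersection below possible.)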
+ # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt + inv_lattice = k2.invert(lattice) + num_seqs = lattice.shape[0] + + b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + while True: + try: + rescoring_lattice = k2.intersect_device( + G_with_epsilon_loops, + inv_lattice, + b_to_a_map, + sorted_match_a=True, + ) + rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice)) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info( + f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" + ) + + # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here + # to avoid OOM. We may need to fine tune it. + inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True) + logging.info( + f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" + ) + + # lat has token IDs as labels + # and word IDs as aux_labels. + lat = k2.invert(rescoring_lattice) + + ans = dict() + # + # The following implements + # scores = (scores - lm_scores)/lm_scale + lm_scores + # = scores/lm_scale + lm_scores*(1 - 1/lm_scale) + # + saved_am_scores = lat.scores - lat.lm_scores + for lm_scale in lm_scale_list: + am_scores = saved_am_scores / lm_scale + lat.scores = am_scores + lat.lm_scores + + best_path = k2.shortest_path(lat, use_double_scores=True) + key = f"lm_scale_{lm_scale}" + ans[key] = best_path + return ans diff --git a/icefall/utils.py b/icefall/utils.py index 4d1ca6cff..1f2cf95f3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -6,7 +6,7 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, List, TextIO, Tuple, Union +from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 import k2.ragged as k2r @@ -171,7 +171,7 @@ def encode_supervisions( def get_texts(best_paths: k2.Fsa) -> List[List[int]]: - """Extract the texts from the best-path FSAs. + """Extract the texts (as word IDs) from the best-path FSAs. Args: best_paths: A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. @@ -204,9 +204,60 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: return k2r.to_list(aux_labels) +def store_transcripts( + filename: Pathlike, texts: Iterable[Tuple[str, str]] +) -> None: + """Save predicted results and reference transcripts to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the reference transcript + while the second element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for ref, hyp in texts: + print(f"ref={ref}", file=f) + print(f"hyp={hyp}", file=f) + + def write_error_stats( f: TextIO, test_set_name: str, results: List[Tuple[str, str]] ) -> float: + """Write statistics based on predicted results and reference transcripts. + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted results. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, but it is + predicted to `ADDISON` (a substitution error). 
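+      Words that are recognized correctly are written as they are, without
+      any parentheses.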
+
+      Another example is::
+
+        FOR THE FIRST DAY (SIR->*) I THINK
+
+      The reference word `SIR` is missing in the predicted
+      results (a deletion error).
+
+    Args:
+      f:
+        File to which the detailed statistics are written.
+      test_set_name:
+        The name of the test set. It is used to label the written statistics.
+      results:
+        An iterable of tuples. The first element is the reference transcript
+        while the second element is the predicted result.
+    Returns:
+      Return the word error rate as a float.
+    """
     subs: Dict[Tuple[str, str], int] = defaultdict(int)
     ins: Dict[str, int] = defaultdict(int)
     dels: Dict[str, int] = defaultdict(int)