Add LM rescoring.
parent 6f9fe5b906
commit 4a66712406
.github/workflows/test.yml (vendored)
@@ -32,7 +32,7 @@ jobs:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.2.dev20210723"]
k2-version: ["1.2.dev20210724"]
fail-fast: false

steps:
@@ -64,9 +64,7 @@ jobs:
ls -lh
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
echo $PYTHONPATH
# Skip CtcTrainingGraphCompiler since it requires
# k2.ctc_topo, which has not been merged into master
pytest -k "not TestCtc" ./test
pytest ./test

- name: Run tests
if: startsWith(matrix.os, 'macos')
@@ -76,4 +74,4 @@ jobs:
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest -k "not TestCtc" ./test
pytest ./test
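The hunks above drop the `-k "not TestCtc"` filter, presumably because the newer pinned k2 nightly provides k2.ctc_topo, so the previously skipped tests can run. For illustration only (not part of the commit), the old and new test selections can be reproduced programmatically with pytest:

import pytest

# Old CI step: skip the CTC training graph compiler tests.
pytest.main(["-k", "not TestCtc", "./test"])

# New CI step: run the full test suite under ./test.
pytest.main(["./test"])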
@@ -3,8 +3,9 @@

import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple

import k2
import torch
@@ -13,12 +14,19 @@ from model import TdnnLstm

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
from icefall.decode import (
get_lattice,
nbest_decoding,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)

@@ -51,36 +59,73 @@ def get_params() -> AttributeDict:
{
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"lang_dir": Path("data/lang"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"subsampling_factor": 3,
"search_beam": 20,
"output_beam": 8,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"method": "1best", # [1best, nbest]
"num_paths": 30, # used when method is nbest
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "whole-lattice-rescoring",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
}
)
return params


@torch.no_grad()
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
) -> List[Tuple[List[str], List[str]]]:
"""Decode one batch and return a list of tuples containing
`(ref_words, hyp_words)`.
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:

- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It is the return value of :func:`get_params`.
It's the return value of :func:`get_params`.

- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.

model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = HLG.device
feature = batch["inputs"]
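To make the new return contract of decode_one_batch concrete, here is a small self-contained sketch (plain Python, no k2; all word lists are made up) of the dict it now returns and how a caller consumes it, mirroring what decode_dataset does in the next hunk:

# Hypothetical output of decode_one_batch for a batch of two utterances.
# Without rescoring there is a single key:
no_rescore_result = {
    "no_rescore": [
        ["HELLO", "WORLD"],   # hypothesis for utterance 0
        ["GOOD", "MORNING"],  # hypothesis for utterance 1
    ]
}

# With LM rescoring there is one key per lm_scale that was tried:
rescored_result = {
    "lm_scale_0.8": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
    "lm_scale_1.3": [["HELLO", "WORLDS"], ["GOOD", "MORNING"]],
}

# A caller pairs each hypothesis list with the reference transcripts,
# exactly as decode_dataset does below.
ref_texts = ["HELLO WORLD", "GOOD MORNING"]
for key, hyps in rescored_result.items():
    pairs = [(ref.split(), hyp) for ref, hyp in zip(ref_texts, hyps)]
    print(key, pairs)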
@@ -114,29 +159,154 @@ def decode_one_batch(
max_active_states=params.max_active_states,
)

if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
return {key: hyps}

assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]

lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
)
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
)

hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans

texts = supervisions["text"]

def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.

Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
lexicon:
It contains the word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no_rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
results.append((ref_words, hyp_words))

num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]

hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
)

for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))

results[lm_scale].extend(this_batch)

num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
)
return results


def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer

logging.info("Wrote detailed error stats to {}".format(errs_filename))

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)

s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)


@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
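save_results above ranks the decoding settings by WER, so the first row of wer-summary-<test_set>.txt is the best lm_scale. A stand-alone sketch of that selection step with made-up WER values:

# Toy WERs per decoding setting; the keys mirror those produced above.
test_set_wers = {
    "lm_scale_0.8": 7.12,
    "lm_scale_1.1": 6.84,
    "lm_scale_1.3": 6.97,
}

# Same idea as in save_results: sort by WER, best setting first.
ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
print("settings\tWER")
for key, wer in ranked:
    print(f"{key}\t{wer}")

best_key, best_wer = ranked[0]
print(f"best setting: {best_key} with WER {best_wer}")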
@@ -160,6 +330,45 @@ def main():
HLG = HLG.to(device)
assert HLG.requires_grad is False

if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()

if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.words["#0"]

G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
G = k2.Fsa.from_dict(d).to(device)

if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)

# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None

model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
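The line marked CAUTION above maps every arc label at or above the first disambiguation symbol #0 to 0 (epsilon) before G is used for rescoring. A torch-only toy of that masking, with a made-up disambiguation ID and made-up labels:

import torch

# Pretend arc labels of G: ordinary word IDs plus disambiguation symbols.
# Suppose the word table assigns #0 the ID 200000 (a made-up value).
first_word_disambig_id = 200000
labels = torch.tensor([13, 200000, 542, 200001, 7], dtype=torch.int32)

# Same masking as in the commit: disambig symbols become epsilon (0).
labels[labels >= first_word_disambig_id] = 0
print(labels)  # labels is now [13, 0, 542, 0, 7]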
@@ -188,32 +397,18 @@ def main():
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
tot_num_cuts = len(test_dl.dataset.cuts)
num_cuts = 0
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
)

results = []
for batch_idx, batch in enumerate(test_dl):
this_batch = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
)
results.extend(this_batch)

num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
)

errs_filename = params.exp_dir / f"errs-{test_set}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, test_set, results)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)

logging.info("Done!")

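After this refactor, main() no longer loops over batches itself: each test set is one decode_dataset call returning {setting: [(ref_words, hyp_words), ...]} followed by one save_results call. A runnable toy of that control flow with stub functions (the stubs are placeholders for illustration, not the real implementations):

# Toy mirror of the new structure of main().
def decode_dataset_stub(test_set):
    # Pretend result: one (ref_words, hyp_words) pair per setting.
    return {
        "lm_scale_1.1": [(["HELLO", "WORLD"], ["HELLO", "WORLD"])],
        "lm_scale_1.3": [(["HELLO", "WORLD"], ["HELLO", "WORLDS"])],
    }

def save_results_stub(test_set_name, results_dict):
    for key, results in results_dict.items():
        print(test_set_name, key, results)

test_sets = ["test-clean", "test-other"]
for test_set in test_sets:
    results_dict = decode_dataset_stub(test_set)
    save_results_stub(test_set_name=test_set, results_dict=results_dict)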
@@ -1,3 +1,6 @@
import logging
from typing import Dict, List

import k2
import torch

@@ -243,3 +246,284 @@ def nbest_decoding(
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
return best_path_fsa


def compute_am_scores(
lattice: k2.Fsa,
word_fsa_with_epsilon_loops: k2.Fsa,
path_to_seq_map: torch.Tensor,
) -> torch.Tensor:
"""Compute AM scores of n-best lists (represented as word_fsas).

Args:
lattice:
An FsaVec, e.g., the return value of :func:`get_lattice`.
It must have the attribute `lm_scores`.
word_fsa_with_epsilon_loops:
An FsaVec representing an n-best list. Note that it has been processed
by `k2.add_epsilon_self_loops`.
path_to_seq_map:
A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
Returns:
Return a 1-D torch.Tensor containing the AM scores of each path.
`ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")

# k2.compose() currently does not support b_to_a_map. To avoid
# replicating `lats`, we use k2.intersect_device here.
#
# lattice has token IDs as `labels` and word IDs as aux_labels, so we
# need to invert it here.
inv_lattice = k2.invert(lattice)

# Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
# and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)

# Remove its `aux_labels` since it is not needed in the
# following computation
del inv_lattice.aux_labels
inv_lattice = k2.arc_sort(inv_lattice)

am_path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)

am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))

# The `scores` of every arc consist of `am_scores` and `lm_scores`
am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores

am_scores = am_path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)

return am_scores


def rescore_with_n_best_list(
lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
) -> Dict[str, k2.Fsa]:
"""Decode using n-best list with LM rescoring.

`lattice` is a decoding lattice with 3 axes. This function first
extracts `num_paths` paths from `lattice` for each sequence using
`k2.random_paths`. The `am_scores` of these paths are computed.
For each path, its `lm_scores` is computed using `G` (which is an LM).
The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
The path with the largest `tot_scores` within a sequence is used
as the decoding output.

Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
G:
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
num_paths:
It is the size `n` in `n-best` list.
lm_scale_list:
A list containing lm_scale values.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each sequence in the lattice.
"""
device = lattice.device

assert len(lattice.shape) == 3
assert hasattr(lattice, "aux_labels")
assert hasattr(lattice, "lm_scores")

assert G.shape == (1, None, None)
assert G.device == device
assert hasattr(G, "aux_labels") is False

# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)

# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path)

# Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0)

# Remove paths that have identical word sequences.
#
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq
# within a sequence.
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=True, need_new2old_indexes=True
)

seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)

# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)

word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

am_scores = compute_am_scores(
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
)

# Now compute lm_scores
b_to_a_map = torch.zeros_like(path_to_seq_map)
lm_path_lattice = _intersect_device(
G,
word_fsa_with_epsilon_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
)
lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice))
lm_scores = lm_path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)

path_2axes = k2.ragged.remove_axis(path, 0)

ans = dict()
for lm_scale in lm_scale_list:
tot_scores = am_scores / lm_scale + lm_scores

# Remember that we used `k2.ragged.unique_sequences` to remove repeated
# paths to avoid redundant computation in `k2.intersect_device`.
# Now we use `num_repeats` to correct the scores for each path.
#
# NOTE(fangjun): It is commented out as it leads to a worse WER
# tot_scores = tot_scores * num_repeats.values()

ragged_tot_scores = k2.RaggedFloat(
seq_to_path_shape, tot_scores.to(torch.float32)
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

# Use k2.index here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes)

# best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes)

# labels is a k2.RaggedInt with 2 axes [path][phone_id]
# Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path)

labels = k2.ragged.remove_values_eq(labels, -1)

# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values())

best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels

key = f"lm_scale_{lm_scale}"
ans[key] = best_path_fsa

return ans


def rescore_with_whole_lattice(
lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
) -> Dict[str, k2.Fsa]:
"""Use whole lattice to rescore.

Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
G_with_epsilon_loops:
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
lm_scale_list:
A list containing lm_scale values.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value represents the
best decoding path for each sequence in the lattice.
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)

device = lattice.device
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lats.lm_scores here
del lattice.lm_scores
assert hasattr(lattice, "lm_scores") is False

# Now, lattice.scores contains only am_scores

# inv_lattice has word IDs as labels.
# Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]

b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
while True:
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
inv_lattice,
b_to_a_map,
sorted_match_a=True,
)
rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)

# NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here
# to avoid OOM. We may need to fine tune it.
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)

# lat has token IDs as labels
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)

ans = dict()
#
# The following implements
# scores = (scores - lm_scores)/lm_scale + lm_scores
# = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
#
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores

best_path = k2.shortest_path(lat, use_double_scores=True)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans

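Both rescoring helpers sweep lm_scale by combining scores as am_score / lm_scale + lm_score (for the whole-lattice case this is the (scores - lm_scores)/lm_scale + lm_scores line above). The toy below, with made-up scores for two competing paths, shows how the winning path can flip as lm_scale grows; for these numbers the flip happens at lm_scale = 6/4 = 1.5:

import torch

# Made-up total AM and LM log-scores for two candidate paths of one utterance.
# Path 0 has the better AM score, path 1 has the better LM score.
am_scores = torch.tensor([-100.0, -106.0])
lm_scores = torch.tensor([-18.0, -14.0])

for lm_scale in [0.8, 1.3, 2.0]:
    # Same combination as in rescore_with_n_best_list / rescore_with_whole_lattice.
    tot_scores = am_scores / lm_scale + lm_scores
    best = int(torch.argmax(tot_scores))
    print(f"lm_scale_{lm_scale}: best path = {best}, tot_scores = {tot_scores.tolist()}")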
@@ -6,7 +6,7 @@ from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, TextIO, Tuple, Union
from typing import Dict, Iterable, List, TextIO, Tuple, Union

import k2
import k2.ragged as k2r
@@ -171,7 +171,7 @@ def encode_supervisions(


def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the texts from the best-path FSAs.
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
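get_texts returns word IDs; decode.py turns them into words with the lexicon's symbol table (hyps = [[lexicon.words[i] for i in ids] for ids in hyps]). A toy version of that lookup, with a made-up word table standing in for lexicon.words:

# Made-up integer-to-word table standing in for lexicon.words.
words = {1: "HELLO", 2: "WORLD", 3: "GOOD", 4: "MORNING"}

# Pretend output of get_texts(): one list of word IDs per utterance.
hyp_ids = [[1, 2], [3, 4]]

hyps = [[words[i] for i in ids] for ids in hyp_ids]
print(hyps)  # [['HELLO', 'WORLD'], ['GOOD', 'MORNING']]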
@@ -204,9 +204,60 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
return k2r.to_list(aux_labels)


def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
"""Save predicted results and reference transcripts to a file.

Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
Returns:
Return None.
"""
with open(filename, "w") as f:
for ref, hyp in texts:
print(f"ref={ref}", file=f)
print(f"hyp={hyp}", file=f)


def write_error_stats(
f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
) -> float:
"""Write statistics based on predicted results and reference transcripts.

It will write the following to the given file:

- WER
- number of insertions, deletions, substitutions, correct words and total
reference words. For example::

Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)

- The difference between the reference transcript and predicted results.
An instance is given below::

THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

The above example shows that the reference word is `EDISON`, but it is
predicted as `ADDISON` (a substitution error).

Another example is::

FOR THE FIRST DAY (SIR->*) I THINK

The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
Returns:
Return the word error rate as a float.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
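The counts described in the write_error_stats docstring above come from aligning each (ref, hyp) pair and classifying insertions, deletions, substitutions and correct words. The snippet below is an independent illustration of that bookkeeping using Python's difflib; it is not the implementation in this commit, and difflib's alignment is not guaranteed to be the minimum-edit-distance alignment used for WER, so treat the numbers as illustrative:

from difflib import SequenceMatcher

ref = "FOR THE FIRST DAY SIR I THINK".split()
hyp = "FOR THE FIRST DAY I THINK".split()

ins = dels = subs = correct = 0
for tag, i1, i2, j1, j2 in SequenceMatcher(a=ref, b=hyp).get_opcodes():
    if tag == "equal":
        correct += i2 - i1
    elif tag == "replace":
        # A replace block of unequal lengths mixes substitutions with
        # insertions or deletions.
        subs += min(i2 - i1, j2 - j1)
        dels += max(0, (i2 - i1) - (j2 - j1))
        ins += max(0, (j2 - j1) - (i2 - i1))
    elif tag == "delete":
        dels += i2 - i1
    elif tag == "insert":
        ins += j2 - j1

errors = ins + dels + subs
print(
    f"Errors: {ins} insertions, {dels} deletions, {subs} substitutions, "
    f"over {len(ref)} reference words ({correct} correct)"
)
print(f"WER: {100.0 * errors / len(ref):.2f}%")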