import logging
from typing import Dict, List, Optional, Tuple, Union

import k2
import kaldialign
import torch
import torch.nn as nn

from icefall.lexicon import Lexicon


def _get_random_paths(
    lattice: k2.Fsa,
    num_paths: int,
    use_double_scores: bool = True,
    scale: float = 1.0,
):
    """
    Args:
      lattice:
        The decoding lattice, returned by :func:`get_lattice`.
      num_paths:
        It specifies the size `n` in n-best. Note: Paths are selected randomly
        and those containing identical word sequences are removed so that only
        one of them is kept.
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
      scale:
        The scale applied to lattice.scores. A smaller value yields more
        unique paths.
    Returns:
      Return a k2.RaggedInt with 3 axes [seq][path][arc_pos].
    """
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale
    path = k2.random_paths(
        lattice, num_paths=num_paths, use_double_scores=use_double_scores
    )
    lattice.scores = saved_scores
    return path
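
# Note on the helper above: scaling lattice.scores before calling
# k2.random_paths only changes the sampling distribution; the original
# scores are restored before returning, so downstream rescoring still sees
# the unscaled lattice. A sketch of the intended call pattern (the argument
# values here are illustrative only):
#
#   path = _get_random_paths(lattice, num_paths=100, scale=0.5)
#   # path: k2.RaggedInt with axes [seq][path][arc_pos]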


def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 50,
) -> k2.Fsa:
    """This is a wrapper of k2.intersect_device. Its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    a CUDA OOM error.

    The arguments and return value of this function are the same as
    those of k2.intersect_device.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = (num_fsas + batch_size - 1) // batch_size
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index(b_fsas, indexes)
        b_to_a = k2.index(b_to_a_map, indexes)
        path_lattice = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lattice)

    return k2.cat(ans)
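
# Note on the helper above: only b_fsas is batched; a_fsas is passed
# unchanged to every k2.intersect_device call, and b_to_a_map is sliced
# with the same indexes as b_fsas so that each batch keeps a consistent
# mapping. k2.cat then concatenates the per-batch results back into a
# single FsaVec in the original order.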


def get_lattice(
    nnet_output: torch.Tensor,
    HLG: k2.Fsa,
    supervision_segments: torch.Tensor,
    search_beam: float,
    output_beam: float,
    min_active_states: int,
    max_active_states: int,
    subsampling_factor: int = 1,
) -> k2.Fsa:
    """Get the decoding lattice from a decoding graph and neural
    network output.

    Args:
      nnet_output:
        It is the output of a neural model of shape `[N, T, C]`.
      HLG:
        An Fsa, the decoding graph. See also `compile_HLG.py`.
      supervision_segments:
        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
        Each row contains information for a supervision segment. Column 0
        is the `sequence_index` indicating which sequence this segment
        comes from; column 1 specifies the `start_frame` of this segment
        within the sequence; column 2 contains the `duration` of this
        segment.
      search_beam:
        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      output_beam:
        Beam to prune output, similar to lattice-beam in Kaldi. Relative
        to best path of output.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any given
        frame for any given intersection/composition task. This is advisory,
        in that it will try not to have fewer than this number active.
        Set it to zero if there is no constraint.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any given
        frame for any given intersection/composition task. This is advisory,
        in that it will try not to exceed that but may not always succeed.
        You can use a very large number if no constraint is needed.
      subsampling_factor:
        The subsampling factor of the model.
    Returns:
      A lattice containing the decoding result.
    """
    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
    )

    lattice = k2.intersect_dense_pruned(
        HLG,
        dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
        min_active_states=min_active_states,
        max_active_states=max_active_states,
    )

    return lattice
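
# A minimal usage sketch for get_lattice(). The beam and active-state
# values below are hypothetical; pick them to match your recipe:
#
#   lattice = get_lattice(
#       nnet_output=nnet_output,      # [N, T, C] log-probs from the model
#       HLG=HLG,                      # decoding graph
#       supervision_segments=supervision_segments,  # CPU int32, 3 columns
#       search_beam=20,
#       output_beam=8,
#       min_active_states=30,
#       max_active_states=10000,
#       subsampling_factor=4,
#   )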


def one_best_decoding(
    lattice: k2.Fsa, use_double_scores: bool = True
) -> k2.Fsa:
    """Get the best path from a lattice.

    Args:
      lattice:
        The decoding lattice returned by :func:`get_lattice`.
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
    Returns:
      An FsaVec containing linear paths.
    """
    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
    return best_path
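
# A minimal usage sketch for one_best_decoding():
#
#   best_path = one_best_decoding(lattice, use_double_scores=True)
#   # best_path is an FsaVec of linear paths; its aux_labels hold the word
#   # IDs of the best path of each sequence (including 0s and -1s that
#   # callers typically filter out).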


def nbest_decoding(
    lattice: k2.Fsa,
    num_paths: int,
    use_double_scores: bool = True,
    scale: float = 1.0,
) -> k2.Fsa:
    """It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total score of
    each word sequence. The one with the max score is used as the decoding
    output.

    Caution:
      Don't be confused by `best` in the name `n-best`. Paths are selected
      randomly, not by ranking their scores.

    Args:
      lattice:
        The decoding lattice, returned by :func:`get_lattice`.
      num_paths:
        It specifies the size `n` in n-best. Note: Paths are selected randomly
        and those containing identical word sequences are removed so that only
        one of them is kept.
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
      scale:
        The scale applied to lattice.scores. A smaller value yields more
        unique paths.
    Returns:
      An FsaVec containing linear FSAs.
    """
    path = _get_random_paths(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        scale=scale,
    )

    # word_seq is a k2.RaggedInt sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seq = k2.index(lattice.aux_labels, path)
    # Note: the above operation supports also the case when
    # lattice.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seq` still has 3 axes, like `path`.
    # The 3 axes are [seq][path][word_id]

    # Remove 0 (epsilon) and -1 from word_seq
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

    # Remove paths with identical word sequences.
    #
    # k2.ragged.unique_sequences will reorder paths within a seq.
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seq.tot_size(1)
    unique_word_seq, _, new2old = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=False, need_new2old_indexes=True
    )
    # Note: unique_word_seq still has the same axes as word_seq

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path belongs
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seq has only two axes [path][word]
    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

    # word_fsa is an FsaVec with axes [path][state][arc]
    word_fsa = k2.linear_fsa(unique_word_seq)

    # Add epsilon self-loops since we will use k2.intersect_device,
    # which treats epsilon as a normal symbol.
    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

    # lattice has token IDs as labels and word IDs as aux_labels.
    # inv_lattice has word IDs as labels and token IDs as aux_labels
    inv_lattice = k2.invert(lattice)
    inv_lattice = k2.arc_sort(inv_lattice)

    path_lattice = _intersect_device(
        inv_lattice,
        word_fsa_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
    )
    # path_lattice has word IDs as labels and token IDs as aux_labels

    path_lattice = k2.top_sort(k2.connect(path_lattice))

    tot_scores = path_lattice.get_tot_scores(
        use_double_scores=use_double_scores, log_semiring=False
    )

    # RaggedFloat currently supports float32 only.
    # If Ragged<double> is wrapped, we can use k2.RaggedDouble here
    ragged_tot_scores = k2.RaggedFloat(
        seq_to_path_shape, tot_scores.to(torch.float32)
    )

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `path`, we use `new2old` here to convert argmax_indexes
    # to the indexes into `path`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    path_2axes = k2.ragged.remove_axis(path, 0)

    # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_path = k2.index(path_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][token_id]
    # Note that it contains -1s.
    labels = k2.index(lattice.labels.contiguous(), best_path)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lattice.aux_labels, best_path.values())

    best_path_fsa = k2.linear_fsa(labels)
    best_path_fsa.aux_labels = aux_labels
    return best_path_fsa
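
# A usage sketch for nbest_decoding(). The argument values are
# illustrative only:
#
#   best_path = nbest_decoding(
#       lattice=lattice,
#       num_paths=100,
#       use_double_scores=True,
#       scale=0.5,
#   )
#   # best_path is an FsaVec of linear FSAs, one per sequence, with token
#   # IDs as labels and word IDs as aux_labels.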


def compute_am_and_lm_scores(
    lattice: k2.Fsa,
    word_fsa_with_epsilon_loops: k2.Fsa,
    path_to_seq_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute AM and LM scores of n-best lists (represented as word_fsas).

    Args:
      lattice:
        An FsaVec, e.g., the return value of :func:`get_lattice`.
        It must have the attribute `lm_scores`.
      word_fsa_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been processed
        by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
        which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsa_with_epsilon_loops.arcs.dim0().
    Returns:
      Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
      Each tensor's `numel()` equals `word_fsa_with_epsilon_loops.shape[0]`.
    """
    assert len(lattice.shape) == 3
    assert hasattr(lattice, "lm_scores")

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lattice`, we use k2.intersect_device here.
    #
    # lattice has token IDs as `labels` and word IDs as `aux_labels`, so we
    # need to invert it here.
    inv_lattice = k2.invert(lattice)

    # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inv_lattice.aux_labels
    inv_lattice = k2.arc_sort(inv_lattice)

    path_lattice = _intersect_device(
        inv_lattice,
        word_fsa_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
    )

    path_lattice = k2.top_sort(k2.connect(path_lattice))

    # The `scores` of every arc consist of `am_scores` and `lm_scores`
    path_lattice.scores = path_lattice.scores - path_lattice.lm_scores

    am_scores = path_lattice.get_tot_scores(
        use_double_scores=True, log_semiring=False
    )

    path_lattice.scores = path_lattice.lm_scores

    lm_scores = path_lattice.get_tot_scores(
        use_double_scores=True, log_semiring=False
    )

    return am_scores.to(torch.float32), lm_scores.to(torch.float32)
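
# Note on the function above: it relies on path_lattice.scores being the
# sum of the acoustic and LM contributions, with the LM part stored
# separately in the `lm_scores` attribute. Subtracting lm_scores isolates
# the AM part, and both totals are computed with log_semiring=False, i.e.
# the best-path score per word FSA.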


def rescore_with_n_best_list(
    lattice: k2.Fsa,
    G: k2.Fsa,
    num_paths: int,
    lm_scale_list: List[float],
    scale: float = 1.0,
) -> Dict[str, k2.Fsa]:
    """Decode using n-best list with LM rescoring.

    `lattice` is a decoding lattice with 3 axes. This function first
    extracts `num_paths` paths from `lattice` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the largest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in the `n-best` list.
      lm_scale_list:
        A list containing lm_scale values.
      scale:
        The scale applied to lattice.scores. A smaller value yields more
        unique paths.
    Returns:
      A dict of FsaVec, whose key is an lm_scale and the value is the
      best decoding path for each sequence in the lattice.
    """
    device = lattice.device

    assert len(lattice.shape) == 3
    assert hasattr(lattice, "aux_labels")
    assert hasattr(lattice, "lm_scores")

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, "aux_labels") is False

    path = _get_random_paths(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=True,
        scale=scale,
    )

    # word_seq is a k2.RaggedInt sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seq = k2.index(lattice.aux_labels, path)

    # Remove epsilons and -1 from word_seq
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

    # Remove paths that have identical word sequences.
    #
    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a sequence.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seq.tot_size(1)
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seq.tot_size(1)
    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seq has only two axes [path][word]
    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

    # word_fsa is an FsaVec with axes [path][state][arc]
    word_fsa = k2.linear_fsa(unique_word_seq)

    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

    am_scores, _ = compute_am_and_lm_scores(
        lattice, word_fsa_with_epsilon_loops, path_to_seq_map
    )

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lattice = _intersect_device(
        G,
        word_fsa_with_epsilon_loops,
        b_to_a_map=b_to_a_map,
        sorted_match_a=True,
    )
    lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice))
    lm_scores = lm_path_lattice.get_tot_scores(
        use_double_scores=True, log_semiring=False
    )

    path_2axes = k2.ragged.remove_axis(path, 0)

    ans = dict()
    for lm_scale in lm_scale_list:
        tot_scores = am_scores / lm_scale + lm_scores

        # Remember that we used `k2.ragged.unique_sequences` to remove repeated
        # paths to avoid redundant computation in `k2.intersect_device`.
        # Now we use `num_repeats` to correct the scores for each path.
        #
        # NOTE(fangjun): It is commented out as it leads to a worse WER
        # tot_scores = tot_scores * num_repeats.values()

        ragged_tot_scores = k2.RaggedFloat(
            seq_to_path_shape, tot_scores.to(torch.float32)
        )
        argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

        # Use k2.index here since argmax_indexes' dtype is torch.int32
        best_path_indexes = k2.index(new2old, argmax_indexes)

        # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
        best_path = k2.index(path_2axes, best_path_indexes)

        # labels is a k2.RaggedInt with 2 axes [path][phone_id]
        # Note that it contains -1s.
        labels = k2.index(lattice.labels.contiguous(), best_path)

        labels = k2.ragged.remove_values_eq(labels, -1)

        # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
        # aux_labels is also a k2.RaggedInt with 2 axes
        aux_labels = k2.index(lattice.aux_labels, best_path.values())

        best_path_fsa = k2.linear_fsa(labels)
        best_path_fsa.aux_labels = aux_labels

        key = f"lm_scale_{lm_scale}"
        ans[key] = best_path_fsa

    return ans
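
# A usage sketch for rescore_with_n_best_list(). The lm_scale values are
# hypothetical; the returned dict has one best-path FsaVec per scale:
#
#   ans = rescore_with_n_best_list(
#       lattice=lattice,
#       G=G,                       # an FsaVec containing a single LM Fsa
#       num_paths=100,
#       lm_scale_list=[0.6, 0.8, 1.0],
#   )
#   best_path = ans["lm_scale_0.6"]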


def rescore_with_whole_lattice(
    lattice: k2.Fsa,
    G_with_epsilon_loops: k2.Fsa,
    lm_scale_list: Optional[List[float]] = None,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
    """Use the whole lattice to rescore.

    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values or None.
    Returns:
      If lm_scale_list is not None, return a dict of FsaVec, whose key
      is an lm_scale and the value represents the best decoding path for
      each sequence in the lattice.
      If lm_scale_list is None, return a lattice that is rescored
      with the given LM.
    """
    assert len(lattice.shape) == 3
    assert hasattr(lattice, "lm_scores")
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lattice.device
    lattice.scores = lattice.scores - lattice.lm_scores
    # We will use lm_scores from G, so remove lattice.lm_scores here
    del lattice.lm_scores
    assert hasattr(lattice, "lm_scores") is False

    assert hasattr(G_with_epsilon_loops, "lm_scores")

    # Now, lattice.scores contains only am_scores

    # inv_lattice has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    inv_lattice = k2.invert(lattice)
    num_seqs = lattice.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    while True:
        try:
            rescoring_lattice = k2.intersect_device(
                G_with_epsilon_loops,
                inv_lattice,
                b_to_a_map,
                sorted_match_a=True,
            )
            rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
            break
        except RuntimeError as e:
            logging.info(f"Caught exception:\n{e}\n")
            logging.info(
                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
            )

            # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here
            # to avoid OOM. We may need to fine tune it.
            inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True)
            logging.info(
                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
            )

    # lat has token IDs as labels
    # and word IDs as aux_labels.
    lat = k2.invert(rescoring_lattice)

    if lm_scale_list is None:
        return lat

    ans = dict()
    #
    # The following implements
    #   scores = (scores - lm_scores)/lm_scale + lm_scores
    #          = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_am_scores = lat.scores - lat.lm_scores
    for lm_scale in lm_scale_list:
        am_scores = saved_am_scores / lm_scale
        lat.scores = am_scores + lat.lm_scores

        best_path = k2.shortest_path(lat, use_double_scores=True)
        key = f"lm_scale_{lm_scale}"
        ans[key] = best_path
    return ans
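
# A usage sketch for rescore_with_whole_lattice(). With lm_scale_list=None
# it returns the rescored lattice itself (e.g., for further rescoring with
# an attention decoder); with a list of scales it returns one best path per
# scale. The scale values below are illustrative:
#
#   ans = rescore_with_whole_lattice(
#       lattice, G_with_epsilon_loops, lm_scale_list=[0.8, 1.0, 1.2]
#   )
#   best_path = ans["lm_scale_1.0"]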


def nbest_oracle(
    lattice: k2.Fsa,
    num_paths: int,
    ref_texts: List[str],
    lexicon: Lexicon,
    scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Select the best hypothesis given a lattice and a reference transcript.

    The basic idea is to extract `num_paths` paths from the given lattice,
    remove duplicates, and select the one that has the minimum edit distance
    with the corresponding reference transcript as the decoding output.

    The decoding result returned by this function is the best result that
    we can obtain using n-best decoding with all kinds of rescoring
    techniques.

    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
        Note: We assume its aux_labels contain word IDs.
      num_paths:
        The size of `n` in n-best.
      ref_texts:
        A list of reference transcripts. Each entry contains space-separated
        words.
      lexicon:
        It is used to convert word IDs to word symbols.
      scale:
        The scale applied to lattice.scores. A smaller value yields more
        unique paths.
    Returns:
      Return a dict whose key encodes the parameters used when calling this
      function and whose value contains the decoding output, i.e., a list of
      the best hypothesis (a list of words) for each utterance.
      `len(ans[key]) == len(ref_texts)`.
    """
    path = _get_random_paths(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=True,
        scale=scale,
    )

    word_seq = k2.index(lattice.aux_labels, path)
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
    unique_word_seq, _, _ = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=False, need_new2old_indexes=False
    )
    unique_word_ids = k2.ragged.to_list(unique_word_seq)
    assert len(unique_word_ids) == len(ref_texts)
    # unique_word_ids[i] contains all hypotheses of the i-th utterance

    results = []
    for hyps, ref in zip(unique_word_ids, ref_texts):
        # Note: hyps is a list of lists of ints.
        # Each sublist contains a hypothesis.
        ref_words = ref.strip().split()
        # CAUTION: We don't convert ref_words to ref_word_ids
        # since there may exist OOV words in ref_words
        best_hyp_words = None
        min_error = float("inf")
        for hyp_words in hyps:
            hyp_words = [lexicon.word_table[i] for i in hyp_words]
            this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
            if this_error < min_error:
                min_error = this_error
                best_hyp_words = hyp_words
        results.append(best_hyp_words)

    return {f"nbest_{num_paths}_scale_{scale}_oracle": results}


def rescore_with_attention_decoder(
    lattice: k2.Fsa,
    num_paths: int,
    model: nn.Module,
    memory: torch.Tensor,
    memory_key_padding_mask: Optional[torch.Tensor],
    sos_id: int,
    eos_id: int,
    scale: float = 1.0,
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and
    uses an attention decoder to rescore them. The path with the highest
    score is used as the decoding output.

    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      model:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      memory:
        The encoder memory of the given model. It is the output of
        the last torch.nn.TransformerEncoder layer in the given model.
        Its shape is `[T, N, C]`.
      memory_key_padding_mask:
        The padding mask for memory with shape [N, T].
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      scale:
        The scale applied to lattice.scores. A smaller value yields more
        unique paths.
      ngram_lm_scale:
        Optional. It specifies the scale for n-gram LM scores.
      attention_scale:
        Optional. It specifies the scale for attention decoder scores.
    Returns:
      A dict of FsaVec, whose key contains a string
      ngram_lm_scale_attention_scale and the value is the
      best decoding path for each sequence in the lattice.
    """
    # First, extract `num_paths` paths for each sequence.
    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
    path = _get_random_paths(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=True,
        scale=scale,
    )

    # word_seq is a k2.RaggedInt sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seq = k2.index(lattice.aux_labels, path)

    # Remove epsilons and -1 from word_seq
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

    # Remove paths that have identical word sequences.
    #
    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a sequence.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seq.tot_size(1)
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seq.tot_size(1)
    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seq has only two axes [path][word]
    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

    # word_fsa is an FsaVec with axes [path][state][arc]
    word_fsa = k2.linear_fsa(unique_word_seq)

    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

    am_scores, ngram_lm_scores = compute_am_and_lm_scores(
        lattice, word_fsa_with_epsilon_loops, path_to_seq_map
    )
    # Now we use the attention decoder to compute another
    # score: attention_scores.
    #
    # To do that, we have to get the input and output for the attention
    # decoder.

    # CAUTION: The "tokens" attribute is set in the file
    # local/compile_hlg.py
    token_seq = k2.index(lattice.tokens, path)

    # Remove epsilons and -1 from token_seq
    token_seq = k2.ragged.remove_values_leq(token_seq, 0)

    # Remove the seq axis.
    token_seq = k2.ragged.remove_axis(token_seq, 0)

    token_seq, _ = k2.ragged.index(
        token_seq, indexes=new2old, axis=0, need_value_indexes=False
    )

    # Now token_seq is in the same order as unique_word_seq, i.e., each
    # word sequence has its corresponding token IDs.
    token_ids = k2.ragged.to_list(token_seq)

    num_word_seqs = new2old.numel()

    path_to_seq_map_long = path_to_seq_map.to(torch.long)
    expanded_memory = memory.index_select(1, path_to_seq_map_long)

    if memory_key_padding_mask is not None:
        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
            0, path_to_seq_map_long
        )
    else:
        expanded_memory_key_padding_mask = None

    nll = model.decoder_nll(
        memory=expanded_memory,
        memory_key_padding_mask=expanded_memory_key_padding_mask,
        token_ids=token_ids,
        sos_id=sos_id,
        eos_id=eos_id,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == num_word_seqs

    attention_scores = -nll.sum(dim=1)
    assert attention_scores.ndim == 1
    assert attention_scores.numel() == num_word_seqs

    if ngram_lm_scale is None:
        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        ngram_lm_scale_list = [ngram_lm_scale]

    if attention_scale is None:
        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        attention_scale_list = [attention_scale]

    path_2axes = k2.ragged.remove_axis(path, 0)

    ans = dict()
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            tot_scores = (
                am_scores
                + n_scale * ngram_lm_scores
                + a_scale * attention_scores
            )
            ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

            best_path_indexes = k2.index(new2old, argmax_indexes)

            # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
            best_path = k2.index(path_2axes, best_path_indexes)

            # labels is a k2.RaggedInt with 2 axes [path][token_id]
            # Note that it contains -1s.
            labels = k2.index(lattice.labels.contiguous(), best_path)

            labels = k2.ragged.remove_values_eq(labels, -1)

            # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
            # aux_labels is also a k2.RaggedInt with 2 axes
            aux_labels = k2.index(lattice.aux_labels, best_path.values())

            best_path_fsa = k2.linear_fsa(labels)
            best_path_fsa.aux_labels = aux_labels

            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_path_fsa
    return ans
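
# A usage sketch for rescore_with_attention_decoder(). `memory` and
# `memory_key_padding_mask` come from the encoder forward pass of the
# transformer model; leaving ngram_lm_scale/attention_scale as None makes
# the function search over a grid of scales and return one FsaVec per
# (ngram_lm_scale, attention_scale) pair, keyed as
# "ngram_lm_scale_<x>_attention_scale_<y>". The call below is illustrative:
#
#   ans = rescore_with_attention_decoder(
#       lattice=rescored_lattice,
#       num_paths=100,
#       model=model,
#       memory=memory,
#       memory_key_padding_mask=memory_key_padding_mask,
#       sos_id=sos_id,
#       eos_id=eos_id,
#   )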