Refactoring.

This commit is contained in:
Fangjun Kuang 2021-09-18 16:14:20 +08:00
parent 38cfd06ccb
commit 8623983bb7
2 changed files with 139 additions and 950 deletions

@@ -1,913 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
import k2
import kaldialign
import torch
import torch.nn as nn
def _get_random_paths(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
):
"""
Args:
lattice:
The decoding lattice, returned by :func:`get_lattice`.
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return a k2.RaggedTensor with 3 axes [seq][path][arc_pos]
"""
saved_scores = lattice.scores.clone()
lattice.scores *= scale
path = k2.random_paths(
lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
lattice.scores = saved_scores
return path
def _intersect_device(
a_fsas: k2.Fsa,
b_fsas: k2.Fsa,
b_to_a_map: torch.Tensor,
sorted_match_a: bool,
batch_size: int = 50,
) -> k2.Fsa:
"""This is a wrapper of k2.intersect_device and its purpose is to split
b_fsas into several batches and process each batch separately to avoid
CUDA OOM error.
The arguments and return value of this function are the same as
k2.intersect_device.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
return k2.intersect_device(
a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
)
num_batches = (num_fsas + batch_size - 1) // batch_size
splits = []
for i in range(num_batches):
start = i * batch_size
end = min(start + batch_size, num_fsas)
splits.append((start, end))
ans = []
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index_fsa(b_fsas, indexes)
b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
ans.append(path_lattice)
return k2.cat(ans)
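# A small sanity check (illustrative only, not part of the original file) of
# the ceil-division batching used above: 120 FSAs with batch_size=50 are
# processed as the index ranges [0, 50), [50, 100), [100, 120).
def _example_batch_splits() -> List[Tuple[int, int]]:
    num_fsas, batch_size = 120, 50
    num_batches = (num_fsas + batch_size - 1) // batch_size
    splits = [
        (i * batch_size, min((i + 1) * batch_size, num_fsas))
        for i in range(num_batches)
    ]
    assert splits == [(0, 50), (50, 100), (100, 120)]
    return splits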
def get_lattice(
nnet_output: torch.Tensor,
HLG: k2.Fsa,
supervision_segments: torch.Tensor,
search_beam: float,
output_beam: float,
min_active_states: int,
max_active_states: int,
subsampling_factor: int = 1,
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
Args:
nnet_output:
It is the output of a neural model of shape `[N, T, C]`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0
is the `sequence_index` indicating which sequence this segment
comes from; column 1 specifies the `start_frame` of this segment
within the sequence; column 2 contains the `duration` of this
segment.
search_beam:
Decoding beam, e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be modified by
`min_active_states` and `max_active_states`.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
min_active_states:
Minimum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to have fewer than this number active.
Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
subsampling_factor:
The subsampling factor of the model.
Returns:
A lattice containing the decoding result.
"""
dense_fsa_vec = k2.DenseFsaVec(
nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
)
lattice = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states,
)
return lattice
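# An illustrative usage sketch (hypothetical values, not part of the original
# file): decode a batch of two sequences whose nnet_output has 100 and 80
# valid frames. `HLG` and `nnet_output` are assumed to be prepared elsewhere
# (see compile_HLG.py and the model's forward pass); the beam sizes and the
# subsampling factor below are made-up example values.
def _example_get_lattice(HLG: k2.Fsa, nnet_output: torch.Tensor) -> k2.Fsa:
    # One row per supervision segment: [sequence_index, start_frame, duration],
    # measured in frames of nnet_output. It must be a CPU tensor of int32.
    supervision_segments = torch.tensor(
        [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
    )
    return get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=4,
    )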
def one_best_decoding(
lattice: k2.Fsa, use_double_scores: bool = True
) -> k2.Fsa:
"""Get the best path from a lattice.
Args:
lattice:
The decoding lattice returned by :func:`get_lattice`.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Return:
An FsaVec containing linear paths.
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extract n-best paths from the given lattice,
build a word sequence from each path, and compute the total scores
of these sequences in the tropical semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
randomly, not by ranking their scores.
Args:
lattice:
The decoding lattice, returned by :func:`get_lattice`.
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
An FsaVec containing linear FSAs.
"""
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
scale=scale,
)
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove 0 (epsilon) and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
# Remove sequences with identical word sequences.
#
# k2.ragged.unique_sequences will reorder paths within a seq.
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, _, new2old = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=True
)
# Note: unique_word_seq still has the same axes as word_seq
seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path belongs
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
# add epsilon self loops since we will use
# k2.intersect_device, which treats epsilon as a normal symbol
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)
# path_lat has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
argmax_indexes = ragged_tot_scores.argmax()
# Since we invoked `k2.ragged.unique_sequences`, which reorders
# the index from `path`, we use `new2old` here to convert argmax_indexes
# to the indexes into `path`.
#
# Use k2.index_select here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index_select(new2old, argmax_indexes)
path_2axes = path.remove_axis(0)
# best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
# aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.values, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
return best_path_fsa
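# Illustrative usage (hypothetical, not part of the original file): given a
# decoding lattice from `get_lattice`, compare plain 1-best decoding with the
# randomized n-best decoding implemented above. Both return FsaVecs of linear
# FSAs whose aux_labels are word IDs.
def _example_one_best_vs_nbest(lattice: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
    one_best = one_best_decoding(lattice, use_double_scores=True)
    nbest = nbest_decoding(
        lattice, num_paths=100, use_double_scores=True, scale=0.5
    )
    return one_best, nbest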
def compute_am_and_lm_scores(
lattice: k2.Fsa,
word_fsa_with_epsilon_loops: k2.Fsa,
path_to_seq_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute AM scores of n-best lists (represented as word_fsas).
Args:
lattice:
An FsaVec, e.g., the return value of :func:`get_lattice`
It must have the attribute `lm_scores`.
word_fsa_with_epsilon_loops:
An FsaVec representing an n-best list. Note that it has been processed
by `k2.add_epsilon_self_loops`.
path_to_seq_map:
A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
path_to_seq_map.numel() == word_fsa_with_epsilon_loops.arcs.dim0().
Returns:
Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
Each tensor's `numel()` equals `word_fsa_with_epsilon_loops.shape[0]`.
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")
# k2.compose() currently does not support b_to_a_map. To avoid
# replicating `lattice`, we use k2.intersect_device here.
#
# lattice has token IDs as `labels` and word IDs as aux_labels, so we
# need to invert it here.
inv_lattice = k2.invert(lattice)
# Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
# and its `aux_labels` are token IDs (a k2.RaggedTensor with 2 axes).
# Remove its `aux_labels` since it is not needed in the
# following computation
del inv_lattice.aux_labels
inv_lattice = k2.arc_sort(inv_lattice)
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)
path_lattice = k2.top_sort(k2.connect(path_lattice))
# The `scores` of every arc consists of `am_scores` and `lm_scores`
path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
am_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)
path_lattice.scores = path_lattice.lm_scores
lm_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)
return am_scores.to(torch.float32), lm_scores.to(torch.float32)
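# A tiny numeric illustration (made-up values, not part of the original file)
# of the trick used in compute_am_and_lm_scores: every arc score is the sum of
# an AM part and an LM part, so subtracting the per-arc lm_scores and summing
# over a path gives the path's AM score, while summing lm_scores alone gives
# its LM score.
def _example_score_decomposition() -> None:
    arc_scores = torch.tensor([-1.5, -0.2, -3.0])  # am + lm, per arc
    arc_lm_scores = torch.tensor([-0.5, -0.1, -1.0])  # lm part, per arc
    am_score = (arc_scores - arc_lm_scores).sum()  # AM score of the path
    lm_score = arc_lm_scores.sum()  # LM score of the path
    assert torch.isclose(am_score + lm_score, arc_scores.sum())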
def rescore_with_n_best_list(
lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
scale: float = 1.0,
) -> Dict[str, k2.Fsa]:
"""Decode using n-best list with LM rescoring.
`lattice` is a decoding lattice with 3 axes. This function first
extracts `num_paths` paths from `lattice` for each sequence using
`k2.random_paths`. The `am_scores` of these paths are computed.
For each path, its `lm_scores` is computed using `G` (which is an LM).
The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
The path with the largest `tot_scores` within a sequence is used
as the decoding output.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
G:
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
num_paths:
It is the size `n` in `n-best` list.
lm_scale_list:
A list containing lm_scale values.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each sequence in the lattice.
"""
device = lattice.device
assert len(lattice.shape) == 3
assert hasattr(lattice, "aux_labels")
assert hasattr(lattice, "lm_scores")
assert G.shape == (1, None, None)
assert G.device == device
assert hasattr(G, "aux_labels") is False
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
# Remove paths that have identical word sequences.
#
# unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq
# within a sequence.
#
# num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path.
# num_repeats.numel() == unique_word_seq.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = word_seq.unique(
need_num_repeats=True, need_new2old_indexes=True
)
seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
am_scores, _ = compute_am_and_lm_scores(
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
)
# Now compute lm_scores
b_to_a_map = torch.zeros_like(path_to_seq_map)
lm_path_lattice = _intersect_device(
G,
word_fsa_with_epsilon_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
)
lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice))
lm_scores = lm_path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)
path_2axes = path.remove_axis(0)
ans = dict()
for lm_scale in lm_scale_list:
tot_scores = am_scores / lm_scale + lm_scores
# Remember that we used `k2.RaggedTensor.unique` to remove repeated
# paths to avoid redundant computation in `k2.intersect_device`.
# Now we use `num_repeats` to correct the scores for each path.
#
# NOTE(fangjun): It is commented out as it leads to a worse WER
# tot_scores = tot_scores * num_repeats.values()
ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
argmax_indexes = ragged_tot_scores.argmax()
# Use k2.index_select here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
# aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.values, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
key = f"lm_scale_{lm_scale}"
ans[key] = best_path_fsa
return ans
def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
"""Use whole lattice to rescore.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
G_with_epsilon_loops:
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
lm_scale_list:
A list containing lm_scale values or None.
Returns:
If lm_scale_list is not None, return a dict of FsaVec, whose key
is an lm_scale and the value represents the best decoding path for
each sequence in the lattice.
If lm_scale_list is None, return a lattice that is rescored
with the given LM.
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
device = lattice.device
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lattice.lm_scores here
del lattice.lm_scores
assert hasattr(lattice, "lm_scores") is False
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
# Its aux_labels are token IDs, which is a k2.RaggedTensor with 2 axes
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
while True:
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
inv_lattice,
b_to_a_map,
sorted_match_a=True,
)
rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
# NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here
# to avoid OOM. We may need to fine tune it.
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
# lat has token IDs as labels
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)
if lm_scale_list is None:
return lat
ans = dict()
#
# The following implements
# scores = (scores - lm_scores)/lm_scale + lm_scores
# = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
#
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
best_path = k2.shortest_path(lat, use_double_scores=True)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
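# A quick numeric check (illustrative only) of the identity used above:
# (scores - lm_scores) / lm_scale + lm_scores
#     == scores / lm_scale + lm_scores * (1 - 1 / lm_scale)
def _example_lm_scale_identity() -> None:
    scores = torch.tensor([-12.0, -7.5])
    lm_scores = torch.tensor([-4.0, -2.5])
    lm_scale = 0.7
    lhs = (scores - lm_scores) / lm_scale + lm_scores
    rhs = scores / lm_scale + lm_scores * (1.0 - 1.0 / lm_scale)
    assert torch.allclose(lhs, rhs)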
def nbest_oracle(
lattice: k2.Fsa,
num_paths: int,
ref_texts: List[str],
word_table: k2.SymbolTable,
scale: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
The basic idea is to extract n paths from the given lattice, unique them,
and select the one that has the minimum edit distance with the corresponding
reference transcript as the decoding output.
The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
Note: We assume its aux_labels contain word IDs.
num_paths:
The size of `n` in n-best.
ref_texts:
A list of reference transcripts. Each entry contains
space-separated words.
word_table:
It is the word symbol table.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Return:
Return a dict whose key encodes the parameters used when calling this
function and whose value is a list of decoding outputs, one entry per
utterance, i.e., `len(ans_dict[key]) == len(ref_texts)`.
"""
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=False
)
unique_word_ids = unique_word_seq.tolist()
assert len(unique_word_ids) == len(ref_texts)
# unique_word_ids[i] contains all hypotheses of the i-th utterance
results = []
for hyps, ref in zip(unique_word_ids, ref_texts):
# Note: hyps is a list of lists of ints.
# Each sublist contains a hypothesis
ref_words = ref.strip().split()
# CAUTION: We don't convert ref_words to ref_words_ids
# since there may exist OOV words in ref_words
best_hyp_words = None
min_error = float("inf")
for hyp_words in hyps:
hyp_words = [word_table[i] for i in hyp_words]
this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
if this_error < min_error:
min_error = this_error
best_hyp_words = hyp_words
results.append(best_hyp_words)
return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each sequence in the lattice.
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedTensor with axes [seq][path][arc_pos]
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
# Remove paths that have identical word sequences.
#
# unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq
# within a sequence.
#
# num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path.
# num_repeats.numel() == unique_word_seq.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = word_seq.unique(
need_num_repeats=True, need_new2old_indexes=True
)
seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
am_scores, ngram_lm_scores = compute_am_and_lm_scores(
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
)
# Now we use the attention decoder to compute another
# score: attention_scores.
#
# To do that, we have to get the input and output for the attention
# decoder.
# CAUTION: The "tokens" attribute is set in the file
# local/compile_hlg.py
if isinstance(lattice.tokens, torch.Tensor):
token_seq = k2.ragged.index(lattice.tokens, path)
else:
token_seq = lattice.tokens.index(path)
token_seq = token_seq.remove_axis(token_seq.num_axes - 2)
# Remove epsilons and -1 from token_seq
token_seq = token_seq.remove_values_leq(0)
# Remove the seq axis.
token_seq = token_seq.remove_axis(0)
token_seq, _ = token_seq.index(
indexes=new2old, axis=0, need_value_indexes=False
)
# Now each word sequence in unique_word_seq has its corresponding token IDs.
token_ids = token_seq.tolist()
num_word_seqs = new2old.numel()
path_to_seq_map_long = path_to_seq_map.to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
if memory_key_padding_mask is not None:
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
else:
expanded_memory_key_padding_mask = None
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs
attention_scores = -nll.sum(dim=1)
assert attention_scores.ndim == 1
assert attention_scores.numel() == num_word_seqs
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
path_2axes = path.remove_axis(0)
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
am_scores
+ n_scale * ngram_lm_scores
+ a_scale * attention_scores
)
ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
argmax_indexes = ragged_tot_scores.argmax()
best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = labels.remove_values_eq(-1)
if isinstance(lattice.aux_labels, torch.Tensor):
aux_labels = k2.index_select(
lattice.aux_labels, best_path.values
)
else:
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.values, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path_fsa
return ans
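# An illustrative sketch (made-up scores, not part of the original file) of
# how the grid over (ngram_lm_scale, attention_scale) combines the three
# per-path scores and forms the keys of the returned dict.
def _example_combine_scores() -> Dict[str, int]:
    am_scores = torch.tensor([-10.0, -11.0])
    ngram_lm_scores = torch.tensor([-3.0, -2.0])
    attention_scores = torch.tensor([-5.0, -4.5])
    ans = dict()
    for n_scale in (0.5, 1.0):
        for a_scale in (0.5, 1.0):
            tot_scores = (
                am_scores
                + n_scale * ngram_lm_scores
                + a_scale * attention_scores
            )
            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            # Index of the best path under this pair of scales
            ans[key] = int(tot_scores.argmax())
    return ans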

@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: This file is a refactor of decode.py
# We will delete decode.py and rename this file to decode.py
import logging
from typing import Dict, List, Optional, Union
@@ -38,7 +35,7 @@ def _intersect_device(
CUDA OOM error.
The arguments and return value of this function are the same as
k2.intersect_device.
:func:`k2.intersect_device`.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
@@ -107,7 +104,9 @@ class Nbest(object):
An Nbest object contains two fields:
(1) fsa. It is an FsaVec containing a vector of **linear** FSAs.
Its axes are [path][state][arc]
(2) shape. Its type is :class:`k2.RaggedShape`.
Its axes are [utt][path]
The field `shape` has two axes [utt][path]. `shape.dim0` contains
the number of utterances, which is also the number of rows in the
@@ -116,11 +115,19 @@ class Nbest(object):
Caution:
Don't be confused by the name `Nbest`. The best in the name `Nbest`
has nothing to do with the `best scores`. The important part is
has nothing to do with `best scores`. The important part is
`N` in `Nbest`, not `best`.
"""
def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
"""
Args:
fsa:
An FsaVec with axes [path][state][arc]. It is expected to contain
a list of **linear** FSAs.
shape:
A ragged shape with two axes [utt][path].
"""
assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}"
assert shape.num_axes == 2, f"num_axes: {shape.num_axes}"
@@ -135,8 +142,8 @@ class Nbest(object):
def __str__(self):
s = "Nbest("
s += f"num_seqs:{self.shape.dim0}, "
s += f"num_fsas:{self.fsa.shape[0]})"
s += f"Number of utterances:{self.shape.dim0}, "
s += f"Number of Paths:{self.fsa.shape[0]})"
return s
@staticmethod
@@ -148,6 +155,11 @@ class Nbest(object):
) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
Each sampled path is a linear FSA.
We assume `lattice.labels` contains token IDs and `lattice.aux_labels`
contains word IDs.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
@@ -159,13 +171,15 @@ class Nbest(object):
False to use single precision.
scale:
Scale `lattice.score` before passing it to :func:`k2.random_paths`.
A smaller value leads to more unique paths with the risk being not
A smaller value leads to more unique paths at the risk of not
sampling the path with the best score.
Returns:
Return an Nbest instance.
"""
saved_scores = lattice.scores.clone()
lattice.scores *= lattice_score_scale
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(
lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
@@ -174,6 +188,7 @@ class Nbest(object):
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [utt][path][word_id]
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
@@ -224,28 +239,28 @@ class Nbest(object):
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels
# Caution: fsa.scores are all 0s.
# `fsa` has only one extra attribute: aux_labels.
return Nbest(fsa=fsa, shape=utt_to_path_shape)
def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
"""Intersect this Nbest object with a lattice and get 1-best
path from the resulting FsaVec.
"""Intersect this Nbest object with a lattice, get 1-best
path from the resulting FsaVec, and return a new Nbest object.
The purpose of this function is to attach scores to an Nbest.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
we assume its `labels` are token IDs and `aux_labels` are word IDs.
If it has only `labels`, we assume it `labels` are word IDs.
If it has only `labels`, we assume its `labels` are word IDs.
use_double_scores:
True to use double precision when computing shortest path.
False to use single precision.
Returns:
Return a new Nbest. This new Nbest shares the same shape with `self`,
while its `fsa` is the 1-best path from intersecting `self.fsa` and
`lattice`.
`lattice`. Also, its `fsa` has non-zero scores and inherits attributes
from `lattice`.
"""
# Note: We view each linear FSA as a word sequence
# and we use the passed lattice to give each word sequence a score.
@@ -308,7 +323,8 @@ class Nbest(object):
an utterance).
Hint:
`self.fsa.scores` contains two parts: am scores and lm scores.
`self.fsa.scores` contains two parts: acoustic scores (AM scores)
and n-gram language model scores (LM scores).
Caution:
We require that ``self.fsa`` has an attribute ``lm_scores``.
@@ -334,7 +350,8 @@ class Nbest(object):
an utterance).
Hint:
`self.fsa.scores` contains two parts: am scores and lm scores.
`self.fsa.scores` contains two parts: acoustic scores (AM scores)
and n-gram language model scores (LM scores).
Caution:
We require that ``self.fsa`` has an attribute ``lm_scores``.
@@ -348,9 +365,6 @@ class Nbest(object):
# The `scores` of every arc consists of `am_scores` and `lm_scores`
self.fsa.scores = self.fsa.lm_scores
# Caution: self.fsa.lm_scores is per arc
# while lm_scores in the following is per path
#
lm_scores = self.fsa.get_tot_scores(
use_double_scores=True, log_semiring=False
)
@@ -359,17 +373,16 @@ class Nbest(object):
return k2.RaggedTensor(self.shape, lm_scores)
def tot_scores(self) -> k2.RaggedTensor:
"""Get total scores of the FSAs in this Nbest.
"""Get total scores of FSAs in this Nbest.
Note:
Since FSAs in Nbest are just linear FSAs, log-semirng and tropical
semiring produce the same total scores.
Since FSAs in Nbest are just linear FSAs, log-semiring
and tropical semiring produce the same total scores.
Returns:
Return a ragged tensor with two axes [utt][path_scores].
Its dtype is torch.float64.
"""
# Use single precision since there are only additions.
scores = self.fsa.get_tot_scores(
use_double_scores=True, log_semiring=False
)
@@ -382,6 +395,25 @@ class Nbest(object):
return k2.Fsa.from_fsas(word_levenshtein_graphs)
def one_best_decoding(
lattice: k2.Fsa,
use_double_scores: bool = True,
) -> k2.Fsa:
"""Get the best path from a lattice.
Args:
lattice:
The decoding lattice returned by :func:`get_lattice`.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Return:
An FsaVec containing linear paths.
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
@@ -390,7 +422,7 @@ def nbest_decoding(
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extra `num_paths` paths from the given lattice,
The basic idea is to first extract `num_paths` paths from the given lattice,
build a word sequence from these paths, and compute the total scores
of the word sequence in the tropical semiring. The one with the max score
is used as the decoding output.
@@ -399,6 +431,10 @@ def nbest_decoding(
Don't be confused by `best` in the name `n-best`. Paths are selected
**randomly**, not by ranking their scores.
Hint:
This decoding method is for demonstration only and it does
not produce a lower WER than :func:`one_best_decoding`.
Args:
lattice:
The decoding lattice, e.g., can be the return value of
@@ -411,10 +447,10 @@ def nbest_decoding(
True to use double precision floating point in the computation.
False to use single precision.
lattice_score_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
It's the scale applied to the `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
Returns:
An FsaVec containing linear FSAs. It axes are [utt][state][arc].
An FsaVec containing **linear** FSAs. Its axes are [utt][state][arc].
"""
nbest = Nbest.from_lattice(
lattice=lattice,
@@ -446,9 +482,9 @@ def nbest_oracle(
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
The basic idea is to extract n paths from the given lattice, unique them,
and select the one that has the minimum edit distance with the corresponding
reference transcript as the decoding output.
The basic idea is to extract `num_paths` paths from the given lattice,
unique them, and select the one that has the minimum edit distance with
the corresponding reference transcript as the decoding output.
The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques.
@@ -458,7 +494,7 @@ def nbest_oracle(
Args:
lattice:
An FsaVec with axes [utt][state][arc].
Note: We assume its aux_labels contain word IDs.
Note: We assume its `aux_labels` contains word IDs.
num_paths:
The size of `n` in n-best.
ref_texts:
@@ -500,6 +536,7 @@ def nbest_oracle(
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)
@@ -536,8 +573,8 @@ def rescore_with_n_best_list(
lattice_score_scale: float = 1.0,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""Rescore a nbest list with an n-gram LM.
The path with a maximum score is used as the decoding output.
"""Rescore an n-best list with an n-gram LM.
The path with the maximum score is used as the decoding output.
Args:
lattice:
@@ -605,7 +642,38 @@ def rescore_with_whole_lattice(
lm_scale_list: Optional[List[float]] = None,
use_double_scores: bool = True,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
# This is not an Nbest based coding method
"""Intersect the lattice with an n-gram LM and use shortest path
to decode.
The input lattice is obtained by intersecting `HLG` with
a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider
this function as second-pass decoding. In the first pass we use a small G,
while in the second pass we use a larger G.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. Its `aux_labels` are word IDs.
It must have an attribute `lm_scores`.
G_with_epsilon_loops:
An FsaVec containing only a single FSA. It contains epsilon self-loops.
It is an acceptor and its labels are word IDs.
lm_scale_list:
Optional. If None, return the intersection of `lattice` and
`G_with_epsilon_loops`.
If not None, it contains a list of values to scale LM scores.
For each scale, there is a corresponding decoding result contained in
the resulting dict.
use_double_scores:
True to use double precision in the computation.
False to use single precision.
Returns:
If `lm_scale_list` is None, return a new lattice which is the intersection
result of `lattice` and `G_with_epsilon_loops`.
Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
value is the decoding result (i.e., an FsaVec containing linear FSAs).
"""
# Nbest is not used in this function
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
@@ -619,7 +687,7 @@ def rescore_with_whole_lattice(
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
# Its aux_labels are token IDs
# Its `aux_labels` contains token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
@@ -668,7 +736,7 @@ def rescore_with_whole_lattice(
lat.scores = am_scores + lat.lm_scores
best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}_yy"
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
@@ -686,6 +754,40 @@ def rescore_with_attention_decoder(
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
lattice_score_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,