From cabe8b625bf55ecabcfb7dc75359d4877c7235c8 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 4 Aug 2021 14:27:11 +0800 Subject: [PATCH] Copy the files related to multi round nbest rescoring from k2 & snowfall --- icefall/nbest.py | 264 +++++++++++++++++++++++++++++++++++++++++++++ icefall/utils.py | 73 +++++++++++++ test/test_nbest.py | 110 +++++++++++++++++++ 3 files changed, 447 insertions(+) create mode 100644 icefall/nbest.py create mode 100644 test/test_nbest.py diff --git a/icefall/nbest.py b/icefall/nbest.py new file mode 100644 index 000000000..1a5394673 --- /dev/null +++ b/icefall/nbest.py @@ -0,0 +1,264 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +# This file implements the ideas proposed by Daniel Povey. +# +# See https://github.com/k2-fsa/snowfall/issues/232 for more details +# +import logging +from typing import List + +import torch +import _k2 +import k2 + +# Note: We use `utterance` and `sequence` interchangeably in the comment + + +class Nbest(object): + ''' + An Nbest object contains two fields: + + (1) fsa, its type is k2.Fsa + (2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape) + + The field `fsa` is an FsaVec containing a vector of **linear** FSAs. + + The field `shape` has two axes [utt][path]. `shape.dim0()` contains + the number of utterances, which is also the number of rows in the + supervision_segments. `shape.tot_size(1)` contains the number + of paths, which is also the number of FSAs in `fsa`. + ''' + + def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None: + assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}' + assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}' + + assert fsa.shape[0] == shape.tot_size(1), \ + f'{fsa.shape[0]} vs {shape.tot_size(1)}' + + self.fsa = fsa + self.shape = shape + + def __str__(self): + s = 'Nbest(' + s += f'num_seqs:{self.shape.dim0()}, ' + s += f'num_fsas:{self.fsa.shape[0]})' + return s + + def intersect(self, lats: k2.Fsa) -> 'Nbest': + '''Intersect this Nbest object with a lattice and get 1-best + path from the resulting FsaVec. + + Caution: + We assume FSAs in `self.fsa` don't have epsilon self-loops. + We also assume `self.fsa.labels` and `lats.labels` are token IDs. + + Args: + lats: + An FsaVec. It can be the return value of + :func:`whole_lattice_rescoring`. + Returns: + Return a new Nbest. This new Nbest shares the same shape with `self`, + while its `fsa` is the 1-best path from intersecting `self.fsa` and + `lats. + ''' + assert self.fsa.device == lats.device, \ + f'{self.fsa.device} vs {lats.device}' + assert len(lats.shape) == 3, f'{lats.shape}' + assert lats.arcs.dim0() == self.shape.dim0(), \ + f'{lats.arcs.dim0()} vs {self.shape.dim0()}' + + lats = k2.arc_sort(lats) # no-op if lats is already arc sorted + + fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa) + + path_to_seq_map = self.shape.row_ids(1) + + ans_lats = k2.intersect_device(a_fsas=lats, + b_fsas=fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + + one_best = k2.shortest_path(ans_lats, use_double_scores=True) + + one_best = k2.remove_epsilon(one_best) + + return Nbest(fsa=one_best, shape=self.shape) + + def total_scores(self) -> _k2.RaggedFloat: + '''Get total scores of the FSAs in this Nbest. + + Note: + Since FSAs in Nbest are just linear FSAs, log-semirng and tropical + semiring produce the same total scores. + + Returns: + Return a ragged tensor with two axes [utt][path_scores]. + ''' + scores = self.fsa.get_tot_scores(use_double_scores=True, + log_semiring=False) + # We use single precision here since we only wrap k2.RaggedFloat. + # If k2.RaggedDouble is wrapped, we can use double precision here. + return _k2.RaggedFloat(self.shape, scores.float()) + + def top_k(self, k: int) -> 'Nbest': + '''Get a subset of paths in the Nbest. The resulting Nbest is regular + in that each sequence (i.e., utterance) has the same number of + paths (k). + + We select the top-k paths according to the total_scores of each path. + If a utterance has less than k paths, then its last path, after sorting + by tot_scores in descending order, is repeated so that each utterance + has exactly k paths. + + Args: + k: + Number of paths in each utterance. + Returns: + Return a new Nbest with a regular shape. + ''' + ragged_scores = self.total_scores() + + # indexes contains idx01's for self.shape + # ragged_scores.values()[indexes] is sorted + indexes = k2.ragged.sort_sublist(ragged_scores, + descending=True, + need_new2old_indexes=True) + + ragged_indexes = k2.RaggedInt(self.shape, indexes) + + padded_indexes = k2.ragged.pad(ragged_indexes, + mode='replicate', + value=-1) + assert torch.ge(padded_indexes, 0).all(), \ + 'Some utterances contain empty ' \ + f'n-best: {self.shape.row_splits(1)}' + + # Select the idx01's of top-k paths of each utterance + top_k_indexes = padded_indexes[:, :k].flatten().contiguous() + + top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes) + + top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(), + dim1=k) + return Nbest(top_k_fsas, top_k_shape) + + +def whole_lattice_rescoring(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa) -> k2.Fsa: + '''Rescore the 1st pass lattice with an LM. + + In general, the G in HLG used to obtain `lats` is a 3-gram LM. + This function replaces the 3-gram LM in `lats` with a 4-gram LM. + + Args: + lats: + The decoding lattice from the 1st pass. We assume it is the result + of intersecting HLG with the network output. + G_with_epsilon_loops: + An LM. It is usually a 4-gram LM with epsilon self-loops. + It should be arc sorted. + Returns: + Return a new lattice rescored with a given G. + ''' + assert len(lats.shape) == 3, f'{lats.shape}' + assert hasattr(lats, 'lm_scores') + assert G_with_epsilon_loops.shape == (1, None, None), \ + f'{G_with_epsilon_loops.shape}' + + device = lats.device + lats.scores = lats.scores - lats.lm_scores + # Now lats contains only acoustic scores + + # We will use lm_scores from the given G, so remove lats.lm_scores here + del lats.lm_scores + assert hasattr(lats, 'lm_scores') is False + + # inverted_lats has word IDs as labels. + # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt + # if lats.aux_labels is a ragged tensor + inverted_lats = k2.invert(lats) + num_seqs = lats.shape[0] + + b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + + while True: + try: + rescoring_lats = k2.intersect_device(G_with_epsilon_loops, + inverted_lats, + b_to_a_map, + sorted_match_a=True) + break + except RuntimeError as e: + logging.info(f'Caught exception:\n{e}\n') + # Usually, this is an OOM exception. We reduce + # the size of the lattice and redo k2.intersect_device() + + # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here + # to avoid OOM. We may need to fine tune it. + logging.info(f'num_arcs before: {inverted_lats.num_arcs}') + inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True) + logging.info(f'num_arcs after: {inverted_lats.num_arcs}') + + rescoring_lats = k2.top_sort(k2.connect(rescoring_lats)) + + # inv_rescoring_lats has token IDs as labels + # and word IDs as aux_labels. + inv_rescoring_lats = k2.invert(rescoring_lats) + return inv_rescoring_lats + + +def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest: + '''Generate an n-best list from a lattice. + + Args: + lats: + The decoding lattice from the first pass after LM rescoring. + lats is an FsaVec. It can be the return value of + :func:`whole_lattice_rescoring` + num_paths: + Size of n for n-best list. CAUTION: After removing paths + that represent the same token sequences, the number of paths + in different sequences may not be equal. + Return: + Return an Nbest object. Note the returned FSAs don't have epsilon + self-loops. + ''' + assert len(lats.shape) == 3 + + # CAUTION: We use `phones` instead of `tokens` here because + # :func:`compile_HLG` uses `phones` + # + # Note: compile_HLG is from k2-fsa/snowfall + assert hasattr(lats, 'phones') + + assert not hasattr(lats, 'tokens') + lats.tokens = lats.phones + # we use tokens instead of phones in the following code + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # token_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains token IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + # Its axes are [seq][path][token_id] + token_seqs = k2.index(lats.tokens, paths) + + # Remove epsilons (0s) and -1 from token_seqs + token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) + + # unique_token_seqs is still a k2.RaggedInt with axes [seq][path]token_id]. + # But then number of pathsin each sequence may be different. + unique_token_seqs, _, _ = k2.ragged.unique_sequences( + token_seqs, need_num_repeats=False, need_new2old_indexes=False) + + seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) + + # Remove the seq axis. + # Now unique_token_seqs has only two axes [path][token_id] + unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) + + token_fsas = k2.linear_fsa(unique_token_seqs) + + return Nbest(fsa=token_fsas, shape=seq_to_path_shape) diff --git a/icefall/utils.py b/icefall/utils.py index 1f2cf95f3..359700a22 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -5,6 +5,7 @@ import subprocess from collections import defaultdict from contextlib import contextmanager from datetime import datetime +from nbest import Nbest from pathlib import Path from typing import Dict, Iterable, List, TextIO, Tuple, Union @@ -381,3 +382,75 @@ def write_error_stats( print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) return float(tot_err_rate) + + +def get_best_matching_stats(keys: Nbest, queries: Nbest, + max_order: int) -> torch.Tensor: + '''Get best matching stats on query positions. + + Args: + keys: + The nbest after doing second pass rescoring. + queries: + Another nbest before doing second pass rescoring. + max_order: + The maximum n-gram order to ever return by `k2.get_best_matching_stats` + + Returns: + A tensor with the shape of [queries.fsa.num_elements, 5], each row + contains the stats (init_score, mean, var, counts_out, ngram_order) + of the token in the correspodding position in queries. + ''' + assert keys.shape.dim0() == queries.shape.dim0(), \ + f'Utterances number in keys and queries should be equal : \ + {keys.shape.dim0()} vs {queries.shape.dim0()}' + + # keys_tokens_shape [utt][path][token] + keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape, + k2.ragged.remove_axis(keys.fsa.arcs.shape(), 1)) + # queries_tokens_shape [utt][path][token] + queries_tokens_shape = k2.ragged.compose_ragged_shapes(queries.shape, + k2.ragged.remove_axis(queries.fsa.arcs.shape(), 1)) + + keys_tokens = k2.RaggedInt(keys_tokens_shape, keys.fsa.labels.clone()) + queries_tokens = k2.RaggedInt(queries_tokens_shape, + queries.fsa.labels.clone()) + # tokens shape [utt][path][token] + tokens = k2.ragged.cat([keys_tokens, queries_tokens], axis=1) + + keys_token_num = keys.fsa.labels.size()[0] + queries_tokens_num = queries.fsa.labels.size()[0] + # counts on key positions are ones + keys_counts = k2.RaggedInt(keys_tokens_shape, + torch.ones(keys_token_num, + dtype=torch.int32)) + # counts on query positions are zeros + queries_counts = k2.RaggedInt(queries_tokens_shape, + torch.zeros(queries_tokens_num, + dtype=torch.int32)) + counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values() + + # scores on key positions are the scores inherted from nbest path + keys_scores = k2.RaggedFloat(keys_tokens_shape, keys.fsa.scores.clone()) + # scores on query positions MUST be zeros + queries_scores = k2.RaggedFloat(queries_tokens_shape, + torch.zeros(queries_tokens_num, + dtype=torch.float32)) + scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values() + + # we didn't remove -1 labels before + min_token = -1 + eos = -1 + max_token = torch.max(torch.max(keys.fsa.labels), + torch.max(queries.fsa.labels)) + mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores, + counts, eos, min_token, max_token, max_order) + + queries_init_scores = queries.fsa.scores.clone() + # only return the stats on query positions + masking = counts == 0 + # shape [queries_tokens_num, 5] + return torch.transpose(torch.stack((queries_init_scores, mean[masking], + var[masking], counts_out[masking], + ngram[masking])), 0, 1) + diff --git a/test/test_nbest.py b/test/test_nbest.py new file mode 100644 index 000000000..ce20a5ca2 --- /dev/null +++ b/test/test_nbest.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R nbest_test_py + +import unittest + +import k2 +import torch + +from icefall.nbest import Nbest + + +class TestNbest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + + def test_nbest_constructor(self): + fsa = k2.Fsa.from_str(''' + 0 1 -1 0.1 + 1 + ''') + + fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa]) + shape = k2.RaggedShape('[[x x] [x]]') + Nbest(fsa_vec, shape) + + def test_top_k(self): + fsa0 = k2.Fsa.from_str(''' + 0 1 -1 0 + 1 + ''') + fsas = [fsa0.clone() for i in range(10)] + fsa_vec = k2.create_fsa_vec(fsas) + fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6], + dtype=torch.float) + # 0 1 2 3 4 5 6 7 8 9 + # [ [3 0] [1 5 4] [2 8 1 9 6] + shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]') + nbest = Nbest(fsa_vec, shape) + + # top_k: k is 1 + nbest1 = nbest.top_k(1) + expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]]) + assert str(nbest1.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x] [x] [x] ]') + assert nbest1.shape == expected_shape + + # top_k: k is 2 + nbest2 = nbest.top_k(2) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8], + fsa_vec[6] + ]) + assert str(nbest2.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]') + assert nbest2.shape == expected_shape + + # top_k: k is 3 + nbest3 = nbest.top_k(3) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4], + fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9] + ]) + assert str(nbest3.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]') + assert nbest3.shape == expected_shape + + # top_k: k is 4 + nbest4 = nbest.top_k(4) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3], + fsa_vec[4], fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6], + fsa_vec[9], fsa_vec[5] + ]) + assert str(nbest4.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]') + assert nbest4.shape == expected_shape + + +if __name__ == '__main__': + unittest.main()