Use an n-gram LM to rescore the lattice from fast_beam_search.

This commit is contained in:
Fangjun Kuang 2022-05-14 20:54:04 +08:00
parent 2d7096dfc6
commit 9ffc77a0f2
5 changed files with 570 additions and 49 deletions

View File

@ -1,8 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# This script downloads the test-clean and test-other datasets # This script downloads the test-clean and test-other datasets
# of LibriSpeech and unzip them to the folder ~/tmp/download, # of LibriSpeech and unzips them to the folder ~/tmp/download,
# which is cached by GitHub actions for later runs. # which are cached by GitHub actions for later runs.
# #
# You will find directories ~/tmp/download/LibriSpeech after running # You will find directories ~/tmp/download/LibriSpeech after running
# this script. # this script.

View File

@ -10,6 +10,25 @@ During training, it selects either a batch from GigaSpeech with prob `giga_prob`
or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within
a batch come from the same dataset. a batch come from the same dataset.
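A minimal sketch of this per-batch corpus selection (illustrative only; `giga_prob` below is an example value, not the recipe's setting):

```python
import random

giga_prob = 0.9  # example value only

def next_batch_source() -> str:
    # Every utterance in the returned batch comes from the chosen corpus.
    return "GigaSpeech" if random.random() < giga_prob else "LibriSpeech"

print([next_batch_source() for _ in range(5)])
```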
#### 2022-05-10
Using commit `TODO`.
The WERs are:
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|----------------------------------------|
| greedy search (max sym per frame 1) | 2.21 | 5.09 | --epoch 27 --avg 2 --max-duration 600 |
| greedy search (max sym per frame 1) | 2.25 | 5.02 | --epoch 27 --avg 12 --max-duration 600 |
| modified beam search | 2.19 | 5.03 | --epoch 25 --avg 6 --max-duration 600 |
| modified beam search | 2.23 | 4.94 | --epoch 27 --avg 10 --max-duration 600 |
| beam search | 2.16 | 4.95 | --epoch 25 --avg 7 --max-duration 600 |
| fast beam search | 2.21 | 4.96 | --epoch 27 --avg 10 --max-duration 600 |
| fast beam search | 2.19 | 4.97 | --epoch 27 --avg 12 --max-duration 600 |
#### 2022-04-29
Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`. Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`.
See <https://github.com/k2-fsa/icefall/pull/312>. See <https://github.com/k2-fsa/icefall/pull/312>.

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import k2 import k2
import sentencepiece as spm
import torch import torch
from model import Transducer from model import Transducer
@ -34,10 +35,11 @@ def fast_beam_search_one_best(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output. the shortest path within the lattice is used as the final output.
Args: Args:
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
@ -74,6 +79,85 @@ def fast_beam_search_one_best(
return hyps return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
we extract `num_paths` paths from the lattice using k2.random_paths(),
remove duplicate paths, compute the total score of each path by
intersecting it with the lattice, and return the path with the largest
total score.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
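# A small, self-contained sketch (not part of the recipe) of the argmax step
# above: nbest.tot_scores() is a ragged tensor with axes [utt][path], and
# argmax() returns, for each utterance, the index (into the flattened path
# axis) of its best-scoring path, which k2.index_fsa then selects.
def _demo_ragged_argmax() -> None:  # helper name is arbitrary
    import k2

    tot_scores = k2.RaggedTensor([[0.5, 1.2, 0.3], [2.0, 1.9]])
    # Expected: [1, 3] -> path 1 for the first utterance, path 3 for the second.
    print(tot_scores.argmax())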
def fast_beam_search_nbest_oracle( def fast_beam_search_nbest_oracle(
model: Transducer, model: Transducer,
decoding_graph: k2.Fsa, decoding_graph: k2.Fsa,
@ -86,10 +170,11 @@ def fast_beam_search_nbest_oracle(
ref_texts: List[List[int]], ref_texts: List[List[int]],
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript that has the minimum edit distance with the given reference transcript
is used as the output. is used as the output.
@ -125,6 +210,8 @@ def fast_beam_search_nbest_oracle(
nbest_scale: nbest_scale:
It's the scale applied to the lattice.scores. A smaller value It's the scale applied to the lattice.scores. A smaller value
yields more unique paths. yields more unique paths.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
@ -137,6 +224,7 @@ def fast_beam_search_nbest_oracle(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -169,6 +257,158 @@ def fast_beam_search_nbest_oracle(
return hyps return hyps
def fast_beam_search_with_nbest_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, `num_paths` paths are
extracted from it and rescored with an n-gram LM, and for each scale in
`ngram_lm_scale_list` the path with the largest combined acoustic + LM
score is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
n-gram LM scale value used during decoding, e.g., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
am_scores = nbest.tot_scores()
# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each path. We assume the decoding_graph
# is an acceptor, i.e., the lattice is also an acceptor.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0
token_list: List[List[int]] = tokens.tolist()
# sp.decode() on a list of token-ID lists returns one string per path.
word_list: List[str] = sp.decode(token_list)
assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]
word_ids_list: List[List[int]] = []
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)
rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)
ans: Dict[str, List[List[int]]] = {}
for s in ngram_lm_scale_list:
key = f"ngram_lm_scale_{s}"
tot_scores = am_scores.values + s * ngram_lm_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
return ans
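# A small, self-contained sketch (not part of the recipe) of the scale sweep
# above: per-path acoustic scores are combined with n-gram LM scores for each
# LM scale, and the best path is the one with the largest combined score.
# All numbers below are made up.
def _demo_ngram_lm_scale_sweep() -> None:  # helper name is arbitrary
    import torch

    am_scores = torch.tensor([-12.3, -11.8, -13.0])  # one entry per path
    ngram_lm_scores = torch.tensor([-45.1, -40.2, -50.7])
    for s in [0.01, 0.02, 0.05]:
        tot_scores = am_scores + s * ngram_lm_scores
        print(f"ngram_lm_scale_{s}", tot_scores.argmax().item())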
def fast_beam_search( def fast_beam_search(
model: Transducer, model: Transducer,
decoding_graph: k2.Fsa, decoding_graph: k2.Fsa,
@ -177,6 +417,7 @@ def fast_beam_search(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -196,6 +437,8 @@ def fast_beam_search(
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns: Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned lattice. Note: When the input graph is a TrivialGraph, the returned
@ -244,7 +487,7 @@ def fast_beam_search(
project_input=False, project_input=False,
) )
logits = logits.squeeze(1).squeeze(1) logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
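# A small, self-contained sketch (not part of the recipe) of the temperature
# change above: dividing the logits by `temperature` before log_softmax
# flattens the distribution for temperature > 1 and sharpens it for
# temperature < 1; temperature = 1.0 reproduces the previous behaviour.
def _demo_softmax_temperature() -> None:  # helper name is arbitrary
    import torch

    logits = torch.tensor([2.0, 1.0, 0.1])
    for temperature in (0.5, 1.0, 2.0):
        probs = (logits / temperature).log_softmax(dim=-1).exp()
        print(temperature, probs)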
@ -587,6 +830,7 @@ def modified_beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -600,6 +844,8 @@ def modified_beam_search(
encoder_out before padding. encoder_out before padding.
beam: beam:
Number of active paths during the beam search. Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns: Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance. for the i-th utterance.
@ -683,7 +929,7 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = (logits / temperature).log_softmax(dim=-1)
log_probs.add_(ys_log_probs) log_probs.add_(ys_log_probs)
@ -847,6 +1093,7 @@ def beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[int]: ) -> List[int]:
""" """
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -860,6 +1107,8 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam: beam:
Beam size. Beam size.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -936,7 +1185,7 @@ def beam_search(
) )
# TODO(fangjun): Scale the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze() log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,) # Now log_prob is (vocab_size,)

View File

@ -19,40 +19,67 @@
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless3/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (not recommended) (2) beam search (not recommended)
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless3/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless3/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (4) fast beam search (one best)
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless3/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest with n-gram LM rescoring)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_with_nbest_rescoring \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5 \
--lm-dir ./data/lm
""" """
@ -69,8 +96,10 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
fast_beam_search_with_nbest_rescoring,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -147,7 +176,9 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_with_nbest_rescoring
""", """,
) )
@ -168,7 +199,9 @@ def get_parser():
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""", fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
""",
) )
parser.add_argument( parser.add_argument(
@ -176,7 +209,9 @@ def get_parser():
type=int, type=int,
default=4, default=4,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""", fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
""",
) )
parser.add_argument( parser.add_argument(
@ -184,7 +219,9 @@ def get_parser():
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""", fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
""",
) )
parser.add_argument( parser.add_argument(
@ -194,6 +231,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
@ -207,7 +245,8 @@ def get_parser():
type=int, type=int,
default=100, default=100,
help="""Number of paths for computed nbest oracle WER help="""Number of paths for computed nbest oracle WER
when the decoding method is fast_beam_search_nbest_oracle. when the decoding method is fast_beam_search_nbest_oracle,
fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring.
""", """,
) )
@ -216,9 +255,40 @@ def get_parser():
type=float, type=float,
default=0.5, default=0.5,
help="""Scale applied to lattice scores when computing nbest paths. help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding_method is fast_beam_search_nbest_oracle. Used only when the decoding_method is fast_beam_search_nbest_oracle,
fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring.
""", """,
) )
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Softmax temperature.
The output of the model is (logits / temperature).log_softmax().
""",
)
parser.add_argument(
"--lm-dir",
type=Path,
default=Path("./data/lm"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--words-txt",
type=Path,
default=Path("./data/lang_bpe_500/words.txt"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It is the word table.
""",
)
return parser return parser
@ -228,6 +298,8 @@ def decode_one_batch(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
word_table: Optional[k2.SymbolTable] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -252,8 +324,17 @@ def decode_one_batch(
for the format of the `batch`. for the format of the `batch`.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is only when decoding method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring.
G:
Optional. Used only when the decoding method is
fast_beam_search_with_nbest_rescoring.
It is an FsaVec containing a single acceptor (the n-gram LM).
word_table:
Optional. Used only when decoding method is
fast_beam_search_with_nbest_rescoring. It is the word symbol table
containing mappings between words and IDs.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -282,6 +363,22 @@ def decode_one_batch(
beam=params.beam, beam=params.beam,
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -297,9 +394,30 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]), ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
temperature=params.temperature,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
hyp_tokens = fast_beam_search_with_nbest_rescoring(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
ngram_lm_scale_list=ngram_lm_scale_list,
num_paths=params.num_paths,
G=G,
sp=sp,
word_table=word_table,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
elif ( elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
@ -317,6 +435,7 @@ def decode_one_batch(
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
temperature=params.temperature,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -338,6 +457,7 @@ def decode_one_batch(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
beam=params.beam_size, beam=params.beam_size,
temperature=params.temperature,
) )
else: else:
raise ValueError( raise ValueError(
@ -352,7 +472,19 @@ def decode_one_batch(
( (
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}" f"max_states_{params.max_states}_"
f"temperature_{params.temperature}"
): hyps
}
elif params.decoding_method == "fast_beam_search_nbest":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}_"
f"temperature_{params.temperature}"
): hyps ): hyps
} }
elif params.decoding_method == "fast_beam_search_nbest_oracle": elif params.decoding_method == "fast_beam_search_nbest_oracle":
@ -362,11 +494,31 @@ def decode_one_batch(
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_" f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_" f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}" f"nbest_scale_{params.nbest_scale}_"
f"temperature_{params.temperature}"
): hyps ): hyps
} }
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
prefix = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}_"
f"temperature_{params.temperature}_"
)
ans: Dict[str, List[List[str]]] = {}
for key, hyp in hyp_tokens.items():
t: List[str] = sp.decode(hyp)
ans[prefix + key] = [s.split() for s in t]
return ans
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {
(
f"beam_size_{params.beam_size}_"
f"temperature_{params.temperature}"
): hyps
}
def decode_dataset( def decode_dataset(
@ -375,6 +527,8 @@ def decode_dataset(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
word_table: Optional[k2.SymbolTable] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -389,7 +543,17 @@ def decode_dataset(
The BPE model. The BPE model.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search. only when decoding method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring.
G:
Optional. Used only when the decoding method is
fast_beam_search_with_nbest_rescoring.
It is an FsaVec containing a single acceptor (the n-gram LM).
word_table:
Optional. Used only when decoding method is
fast_beam_search_with_nbest_rescoring. It is the word symbol table
containing mappings between words and IDs.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -419,6 +583,8 @@ def decode_dataset(
sp=sp, sp=sp,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
batch=batch, batch=batch,
G=G,
word_table=word_table,
) )
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
@ -438,6 +604,7 @@ def decode_dataset(
logging.info( logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}" f"batch {batch_str}, cuts processed until now is {num_cuts}"
) )
return results return results
@ -485,6 +652,68 @@ def save_results(
logging.info(s) logging.info(s)
def load_ngram_LM(
lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
) -> k2.Fsa:
"""Read a ngram model from the given directory.
Args:
lm_dir:
It should contain either G_4_gram.pt or G_4_gram.fst.txt
word_table:
The word table mapping words to IDs and vice versa.
device:
The resulting FSA will be moved to this device.
Returns:
Return an FsaVec containing a single acceptor.
"""
lm_dir = Path(lm_dir)
assert lm_dir.is_dir(), f"{lm_dir} does not exist"
pt_file = lm_dir / "G_4_gram.pt"
if pt_file.is_file():
logging.info(f"Loading pre-compiled {pt_file}")
d = torch.load(pt_file, map_location=device)
G = k2.Fsa.from_dict(d)
return G
txt_file = lm_dir / "G_4_gram.fst.txt"
assert txt_file.is_file(), f"{txt_file} does not exist"
logging.info(f"Loading {txt_file}")
logging.warning("It may take 8 minutes (Will be cached for later use).")
with open(txt_file) as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# Now G is an acceptor
first_word_disambig_id = word_table["#0"]
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
logging.info(f"Saving to {pt_file} for later use")
torch.save(G.as_dict(), pt_file)
return G
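# A small, self-contained sketch (not part of the recipe) of the disambig
# handling in load_ngram_LM above: arcs whose label is a disambiguation
# symbol (id >= the id of "#0") are relabeled as epsilon (0), so G becomes a
# plain word acceptor. The ids below are made up.
def _demo_disambig_relabeling() -> None:  # helper name is arbitrary
    import torch

    first_word_disambig_id = 200004  # assumed id of "#0" in words.txt
    labels = torch.tensor([13, 200004, 200005, 42], dtype=torch.int32)
    labels[labels >= first_word_disambig_id] = 0
    print(labels)  # tensor([13, 0, 0, 42], dtype=torch.int32)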
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
@ -499,7 +728,9 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"fast_beam_search_with_nbest_rescoring",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -513,19 +744,27 @@ def main():
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif params.decoding_method == "fast_beam_search_nbest_oracle": params.suffix += f"-temperature-{params.temperature}"
elif params.decoding_method in (
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"fast_beam_search_with_nbest_rescoring",
):
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-temperature-{params.temperature}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
) )
params.suffix += f"-temperature-{params.temperature}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-temperature-{params.temperature}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
@ -585,12 +824,26 @@ def main():
if params.decoding_method in ( if params.decoding_method in (
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"fast_beam_search_with_nbest_rescoring",
): ):
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else: else:
decoding_graph = None decoding_graph = None
if params.decoding_method == "fast_beam_search_with_nbest_rescoring":
logging.info(f"Loading word symbol table from {params.words_txt}")
word_table = k2.SymbolTable.from_file(params.words_txt)
G = load_ngram_LM(
lm_dir=params.lm_dir,
word_table=word_table,
device=device,
)
else:
word_table = None
G = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -613,6 +866,8 @@ def main():
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
G=G,
word_table=word_table,
) )
save_results( save_results(

View File

@ -308,9 +308,7 @@ class Nbest(object):
del word_fsa.aux_labels del word_fsa.aux_labels
word_fsa.scores.zero_() word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
word_fsa
)
path_to_utt_map = self.shape.row_ids(1) path_to_utt_map = self.shape.row_ids(1)
@ -609,7 +607,7 @@ def rescore_with_n_best_list(
num_paths: num_paths:
Size of nbest list. Size of nbest list.
lm_scale_list: lm_scale_list:
A list of float representing LM score scales. A list of floats representing LM score scales.
nbest_scale: nbest_scale:
Scale to be applied to ``lattice.score`` when sampling paths Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``. using ``k2.random_paths``.