Add fast_beam_search_nbest.

This commit is contained in:
Fangjun Kuang 2022-06-14 17:04:04 +08:00
parent 53f38c01d2
commit 1bf2e17437
8 changed files with 633 additions and 89 deletions

View File

@ -75,6 +75,86 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,

View File

@ -82,6 +82,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -250,6 +251,26 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search and
--use-LG is True.
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search and
--use-LG is True.
""",
)
return parser
@ -307,21 +328,32 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.use_LG:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
else:
if not params.use_LG:
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1

View File

@ -37,7 +37,7 @@ def fast_beam_search_one_best(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
@ -74,6 +74,86 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
@ -89,7 +169,7 @@ def fast_beam_search_nbest_oracle(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.

View File

@ -43,7 +43,7 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
@ -53,6 +53,32 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
"""
@ -69,6 +95,8 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -145,6 +173,8 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
""",
)
@ -164,7 +194,9 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -172,7 +204,8 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -180,7 +213,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -198,6 +232,26 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
return parser
@ -231,7 +285,8 @@ def decode_one_batch(
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -263,6 +318,35 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -318,6 +402,16 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search_nbest" in params.decoding_method:
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -342,7 +436,8 @@ def decode_dataset(
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -360,7 +455,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -452,6 +547,8 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -461,10 +558,16 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -528,7 +631,7 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -19,40 +19,66 @@
Usage:
(1) greedy search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
"""
@ -69,6 +95,7 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
@ -147,6 +174,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
""",
)
@ -168,7 +196,8 @@ def get_parser():
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -176,7 +205,8 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -184,7 +214,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -205,9 +236,10 @@ def get_parser():
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for computed nbest oracle WER
when the decoding method is fast_beam_search_nbest_oracle.
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
@ -216,9 +248,11 @@ def get_parser():
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding_method is fast_beam_search_nbest_oracle.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
return parser
@ -252,8 +286,8 @@ def decode_one_batch(
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is
fast_beam_search or fast_beam_search_nbest_oracle.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -285,6 +319,20 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
@ -355,7 +403,7 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif params.decoding_method == "fast_beam_search_nbest_oracle":
elif "fast_beam_search_nbest" in params.decoding_method:
return {
(
f"beam_{params.beam}_"
@ -389,7 +437,8 @@ def decode_dataset(
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -407,7 +456,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -499,6 +548,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
@ -513,7 +563,7 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif params.decoding_method == "fast_beam_search_nbest_oracle":
elif "fast_beam_search_nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
@ -539,9 +589,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.unk_id()
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -583,10 +633,7 @@ def main():
model.device = device
model.unk_id = params.unk_id
if params.decoding_method in (
"fast_beam_search",
"fast_beam_search_nbest_oracle",
):
if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -44,7 +44,7 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
@ -54,6 +54,32 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
"""
@ -70,6 +96,8 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -159,6 +187,8 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
""",
)
@ -178,7 +208,9 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -186,7 +218,8 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -194,7 +227,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -212,6 +246,26 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
return parser
@ -277,6 +331,35 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -332,6 +415,16 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search_nbest" in params.decoding_method:
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -356,7 +449,8 @@ def decode_dataset(
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -374,7 +468,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -466,6 +560,8 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -475,10 +571,16 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -592,7 +694,7 @@ def main():
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -44,7 +44,7 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
@ -54,6 +54,32 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--num-paths 200 \
--nbest-scale 0.5
"""
@ -70,6 +96,8 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -128,7 +156,7 @@ def get_parser():
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
@ -159,6 +187,8 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
""",
)
@ -178,7 +208,9 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -186,7 +218,8 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -194,7 +227,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, or
fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -212,6 +246,26 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or
fast_beam_search_nbest_oracle
""",
)
add_model_arguments(parser)
return parser
@ -279,6 +333,35 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -334,6 +417,16 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search_nbest" in params.decoding_method:
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -358,7 +451,8 @@ def decode_dataset(
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -468,6 +562,8 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -477,10 +573,16 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -594,7 +696,7 @@ def main():
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -308,9 +308,7 @@ class Nbest(object):
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = self.shape.row_ids(1)
@ -609,7 +607,7 @@ def rescore_with_n_best_list(
num_paths:
Size of nbest list.
lm_scale_list:
A list of float representing LM score scales.
A list of floats representing LM score scales.
nbest_scale:
Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``.