Add fast_beam_search_nbest. (#420)

* Add fast_beam_search_nbest.

* Fix CI errors.

* Fix CI errors.

* More fixes.

* Small fixes.

* Support using log_add in LG decoding with fast_beam_search.

* Support LG decoding in pruned_transducer_stateless.

* Support LG for pruned_transducer_stateless2.

* Support LG for fast beam search.

* Minor fixes.
Fangjun Kuang 2022-06-22 00:09:25 +08:00 committed by GitHub
parent 7100c33820
commit dc89b61b80
10 changed files with 1298 additions and 194 deletions
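
As a quick orientation before the diffs: the sketch below shows how the new fast_beam_search_nbest function added in this commit is meant to be called. It is a hedged illustration only; `model`, `encoder_out`, and `encoder_out_lens` are hypothetical placeholders, and only the keyword arguments and the trivial-graph construction are taken from the code introduced below.

import k2

# Hypothetical inputs: a trained Transducer `model` plus encoder outputs of
# shape (N, T, C) and their lengths (N,). A BPE vocab size of 500 is assumed.
decoding_graph = k2.trivial_graph(500 - 1, device=encoder_out.device)
hyp_tokens = fast_beam_search_nbest(  # defined in beam_search.py below
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=20.0,
    max_states=64,
    max_contexts=8,
    num_paths=200,
    nbest_scale=0.5,
)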


@@ -32,6 +32,12 @@ for sym in 1 2 3; do
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav


@@ -33,13 +33,13 @@ jobs:
# disable macOS test for now.
os: [ubuntu-18.04]
python-version: [3.7, 3.8]
- torch: ["1.8.0", "1.10.0"]
+ torch: ["1.8.0", "1.11.0"]
- torchaudio: ["0.8.0", "0.10.0"]
+ torchaudio: ["0.8.0", "0.11.0"]
- k2-version: ["1.9.dev20211101"]
+ k2-version: ["1.15.1.dev20220427"]
exclude:
- torch: "1.8.0"
- torchaudio: "0.10.0"
+ torchaudio: "0.11.0"
- - torch: "1.10.0"
+ - torch: "1.11.0"
torchaudio: "0.8.0"
fail-fast: false
@@ -67,7 +67,7 @@ jobs:
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
+ if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
else
pip install torchaudio==${{ matrix.torchaudio }}


@@ -75,6 +75,202 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest_LG(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=True, # Note: we always use True
)
# See https://github.com/k2-fsa/icefall/pull/420 for why
# we always use log_semiring=True
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,

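The log_semiring=True comment above (see the PR link) is the behavioral point of this function: paths through LG that collapse to the same word sequence have their probabilities summed (log-add) rather than pruned down to the single best path. A minimal torch sketch of the difference, using made-up path scores:

import torch

# Hypothetical log-scores of three lattice paths yielding the same words.
path_scores = torch.tensor([-1.2, -1.5, -3.0])

best_only = path_scores.max()           # max semiring: keep only the best path
log_add = path_scores.logsumexp(dim=0)  # log semiring: add the probabilities

print(best_only.item(), log_add.item())
# log_add > best_only: probability mass from duplicate paths accumulates,
# which is why the code above always passes log_semiring=True.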

@@ -50,20 +50,44 @@ Usage:
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
- --beam 4 \
+ --beam 20.0 \
- --max-contexts 4 \
+ --max-contexts 8 \
- --max-states 8
+ --max-states 64
- (5) fast beam search using LG
+ (5) fast beam search (nbest)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
- --use-LG True \
- --use-max False \
--max-duration 600 \
- --decoding-method fast_beam_search \
+ --decoding-method fast_beam_search_nbest \
- --beam 8 \
+ --beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
@@ -82,6 +106,9 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@@ -99,7 +126,6 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
- str2bool,
write_error_stats,
)
@@ -153,7 +179,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
- type=str,
+ type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
@@ -167,6 +193,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@@ -182,30 +213,13 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
- default=4,
+ default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
- Used only when --decoding-method is fast_beam_search""",
- )
- parser.add_argument(
- "--use-LG",
- type=str2bool,
- default=False,
- help="""Whether to use an LG graph for FSA-based beam search.
- Used only when --decoding_method is fast_beam_search. If setting true,
- it assumes there is an LG.pt file in lang_dir.""",
- )
- parser.add_argument(
- "--use-max",
- type=str2bool,
- default=False,
- help="""If True, use max-op to select the hypothesis that have the
- max log_prob in case of duplicate hypotheses.
- If False, use log_add.
- Used only for beam_search, modified_beam_search, and fast_beam_search
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
""",
)
@@ -214,7 +228,7 @@ def get_parser():
type=float,
default=0.01,
help="""
- Used only when --decoding_method is fast_beam_search.
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
@@ -222,9 +236,10 @@ def get_parser():
parser.add_argument(
"--max-contexts",
type=int,
- default=4,
+ default=8,
help="""Used only when --decoding-method is
- fast_beam_search""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@@ -232,7 +247,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
- fast_beam_search""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@@ -250,6 +266,24 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
return parser
@@ -286,7 +320,8 @@ def decode_one_batch(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is fast_beam_search.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -299,6 +334,7 @@ def decode_one_batch(
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)

encoder_out, encoder_out_lens = model.encoder(
@@ -316,10 +352,49 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- if params.use_LG:
+ for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
- else:
+ elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
@@ -339,7 +414,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
- use_max=params.use_max,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -361,7 +435,6 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
- use_max=params.use_max,
)
else:
raise ValueError(
@@ -371,14 +444,17 @@ def decode_one_batch(
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
- elif params.decoding_method == "fast_beam_search":
- return {
- (
- f"beam_{params.beam}_"
- f"max_contexts_{params.max_contexts}_"
- f"max_states_{params.max_states}"
- ): hyps
- }
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+ return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@@ -406,7 +482,8 @@ def decode_dataset(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is fast_beam_search.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@@ -424,7 +501,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
- log_interval = 10
+ log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -517,6 +594,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@@ -527,16 +607,18 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
- params.suffix += f"-use-LG-{params.use_LG}"
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
- params.suffix += f"-use-max-{params.use_max}"
+ if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
) )
params.suffix += f"-use-max-{params.use_max}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -596,12 +678,14 @@ def main():
model.eval()
model.device = device
- if params.decoding_method == "fast_beam_search":
- if params.use_LG:
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
+ torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

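One note on --nbest-scale in the usage above: Nbest.from_lattice scales lattice.scores before k2.random_paths samples paths, so a smaller scale flattens the score distribution and yields more unique paths. A rough stand-alone torch illustration with invented arc scores (softmax as a stand-in for the sampler's preference):

import torch

arc_scores = torch.tensor([4.0, 2.0, 0.5])  # hypothetical lattice arc scores

for scale in (1.0, 0.5):
    probs = torch.softmax(arc_scores * scale, dim=0)
    print(scale, probs)
# With scale 0.5 the distribution is closer to uniform, so path sampling is
# more likely to visit lower-scoring alternatives, i.e. more unique paths.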

@@ -37,7 +37,7 @@ def fast_beam_search_one_best(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
- A lattice is first obtained using modified beam search, and then
+ A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
@@ -74,6 +74,202 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest_LG(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=True, # Note: we always use True
)
# See https://github.com/k2-fsa/icefall/pull/420 for why
# we always use log_semiring=True
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
@@ -89,7 +285,7 @@ def fast_beam_search_nbest_oracle(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
- A lattice is first obtained using modified beam search, and then
+ A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.

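The fast_beam_search_nbest_oracle docstring above says the returned path is the one with the minimum edit distance to the reference transcript. A self-contained sketch of just that selection step, with a plain-Python Levenshtein distance and hypothetical token IDs (the real code works on k2 lattices, not Python lists):

def edit_distance(a, b):
    # Standard Levenshtein distance via dynamic programming.
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (x != y))
    return dp[-1]

ref = [5, 12, 7]                           # hypothetical reference token IDs
paths = [[5, 12, 9], [5, 12, 7], [5, 7]]   # hypothetical n-best paths
oracle = min(paths, key=lambda p: edit_distance(p, ref))
assert oracle == [5, 12, 7]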

@@ -43,16 +43,53 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
- (4) fast beam search
+ (4) fast beam search (one best)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
- --beam 4 \
+ --beam 20.0 \
- --max-contexts 4 \
+ --max-contexts 8 \
- --max-states 8
+ --max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@@ -69,6 +106,9 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@@ -81,6 +121,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@@ -136,6 +177,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -145,6 +193,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@@ -160,27 +213,42 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
- default=4,
+ default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
- Used only when --decoding-method is fast_beam_search""",
+ Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
- default=4,
+ default=8,
help="""Used only when --decoding-method is
- fast_beam_search""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
- default=8,
+ default=64,
help="""Used only when --decoding-method is
- fast_beam_search""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@@ -198,6 +266,24 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
return parser
@@ -206,6 +292,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -229,9 +316,12 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is fast_beam_search.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -263,6 +353,49 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@@ -318,6 +451,17 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@@ -327,6 +471,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -340,9 +485,12 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is fast_beam_search.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@@ -360,7 +508,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
- log_interval = 10
+ log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -370,6 +518,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
@@ -452,6 +601,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@@ -465,6 +617,11 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -528,10 +685,24 @@ def main():
model.eval()
model.device = device
- if params.decoding_method == "fast_beam_search":
- decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -553,6 +724,7 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

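In the fast_beam_search_nbest_LG branch above, hypotheses come back as word IDs and are mapped through word_table. A toy sketch of just that mapping; the real table is a k2.SymbolTable obtained from Lexicon(params.lang_dir), which likewise supports integer indexing:

# Toy stand-in for the word table loaded from the lang dir.
word_table = {1: "HELLO", 2: "WORLD"}

hyp_tokens = [[1, 2], [2]]  # hypothetical word-ID hypotheses from the decoder
hyps = [[word_table[i] for i in hyp] for hyp in hyp_tokens]
assert hyps == [["HELLO", "WORLD"], ["WORLD"]]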

@@ -43,16 +43,53 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
- (4) fast beam search
+ (4) fast beam search (one best)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
- --beam 4 \
+ --beam 20.0 \
- --max-contexts 4 \
+ --max-contexts 8 \
- --max-states 8
+ --max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@@ -69,6 +106,8 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
@@ -83,6 +122,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@@ -138,6 +178,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -147,7 +194,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@@ -163,28 +214,42 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
- default=4,
+ default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
- Used only when --decoding-method is
- fast_beam_search or fast_beam_search_nbest_oracle""",
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
- default=4,
+ default=8,
help="""Used only when --decoding-method is
- fast_beam_search or fast_beam_search_nbest_oracle""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
- default=8,
+ default=64,
help="""Used only when --decoding-method is
- fast_beam_search or fast_beam_search_nbest_oracle""",
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@@ -205,10 +270,10 @@ def get_parser():
parser.add_argument(
"--num-paths",
type=int,
- default=100,
+ default=200,
- help="""Number of paths for computed nbest oracle WER
- when the decoding method is fast_beam_search_nbest_oracle.
- """,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@@ -216,9 +281,10 @@ def get_parser():
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
- Used only when the decoding_method is fast_beam_search_nbest_oracle.
- """,
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
return parser
@@ -227,6 +293,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -250,10 +317,12 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is
- fast_beam_search or fast_beam_search_nbest_oracle.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -285,6 +354,34 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle": elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle( hyp_tokens = fast_beam_search_nbest_oracle(
model=model, model=model,
@@ -355,16 +452,25 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
- elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
- f"max_states_{params.max_states}_"
- f"num_paths_{params.num_paths}_"
- f"nbest_scale_{params.nbest_scale}"
+ f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@@ -374,6 +480,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -387,9 +494,12 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
- only when --decoding_method is fast_beam_search.
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@@ -407,7 +517,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
- log_interval = 10
+ log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -417,6 +527,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
@@ -499,6 +610,8 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
@@ -509,16 +622,15 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
- if params.decoding_method == "fast_beam_search":
+ if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
- elif params.decoding_method == "fast_beam_search_nbest_oracle":
- params.suffix += f"-beam-{params.beam}"
- params.suffix += f"-max-contexts-{params.max_contexts}"
- params.suffix += f"-max-states-{params.max_states}"
- params.suffix += f"-num-paths-{params.num_paths}"
+ if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -539,9 +651,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
- # <blk> and <unk> is defined in local/train_bpe_model.py
+ # <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
- params.unk_id = sp.unk_id()
+ params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@@ -583,13 +695,24 @@ def main():
model.device = device
model.unk_id = params.unk_id
- if params.decoding_method in (
- "fast_beam_search",
- "fast_beam_search_nbest_oracle",
- ):
- decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -612,6 +735,7 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

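A last note on `decoding_graph.scores *= params.ngram_lm_scale` in main() above: the LG graph's n-gram LM scores are scaled down before intersection so the acoustic model stays dominant. A toy torch illustration of why the default scale is small (0.01), with invented scores:

import torch

lm_scores = torch.tensor([-8.0, -2.0])  # hypothetical LG arc scores (log-probs)
acoustic = torch.tensor([-1.0, -4.0])   # hypothetical acoustic scores

for scale in (1.0, 0.01):
    total = acoustic + scale * lm_scores
    print(scale, total.argmax().item())
# With scale 1.0 the LM flips the decision (prints 1); with 0.01 the acoustic
# model dominates (prints 0) and the LM acts only as a mild tie-breaker.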

@@ -44,16 +44,53 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
- (4) fast beam search
+ (4) fast beam search (one best)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
- --beam 4 \
+ --beam 20.0 \
- --max-contexts 4 \
+ --max-contexts 8 \
- --max-states 8
+ --max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@@ -70,6 +107,9 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,

@@ -83,6 +123,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,

@@ -150,6 +191,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,

@@ -159,6 +207,11 @@ def get_parser():
         - beam_search
         - modified_beam_search
         - fast_beam_search
+        - fast_beam_search_nbest
+        - fast_beam_search_nbest_oracle
+        - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
         """,
     )
@ -174,27 +227,42 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--beam", "--beam",
type=float, type=float,
default=4, default=20.0,
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""", Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
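As a quick illustration of the `cutoff = max-score - beam` rule described in the help string (the scores below are made up):

    # With the new default beam of 20.0, a state survives pruning only if
    # its score is within 20.0 of the best score on the current frame.
    best_score = -5.0           # hypothetical best log-score on a frame
    beam = 20.0
    cutoff = best_score - beam  # -25.0; states scoring below this are pruned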
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
-        default=4,
+        default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
     )

     parser.add_argument(
         "--max-states",
         type=int,
-        default=8,
+        default=64,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
     )

     parser.add_argument(

@@ -212,6 +280,24 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
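The new --nbest-scale flag feeds the nbest_scale argument of the fast_beam_search_nbest* functions: lattice scores are scaled down before paths are sampled, so that k2.random_paths() draws a more diverse n-best list rather than repeatedly picking the single best path. A rough sketch of the idea, assuming `lattice` is a k2.Fsa produced by fast beam search (this is not the exact library code):

    # Temporarily flatten the score distribution before sampling paths.
    saved = lattice.scores.clone()
    lattice.scores *= 0.5  # nbest_scale; smaller => more diverse paths
    paths = k2.random_paths(lattice, use_double_scores=True, num_paths=200)
    lattice.scores = saved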
     return parser

@@ -220,6 +306,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the

@@ -243,9 +330,12 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.

@@ -277,6 +367,49 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
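Note the two output conventions in the new branches above: fast_beam_search_nbest_LG returns word IDs, which are looked up in word_table, while the other variants return BPE token IDs that SentencePiece converts back into text. An illustration with made-up IDs:

    hyp = [27, 512]                       # word IDs from the LG branch
    words = [word_table[i] for i in hyp]  # e.g. ["HELLO", "WORLD"]

    hyp_tokens = [[5, 98, 3]]             # BPE IDs from the nbest branch
    hyps = [s.split() for s in sp.decode(hyp_tokens)]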
@@ -324,14 +457,17 @@ def decode_one_batch(
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
-    elif params.decoding_method == "fast_beam_search":
-        return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
-        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+        return {key: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -341,6 +477,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.

@@ -354,9 +491,12 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.

@@ -374,7 +514,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 50
     else:
-        log_interval = 10
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):

@@ -385,6 +525,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            word_table=word_table,
             batch=batch,
         )

@@ -466,6 +607,9 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method

@@ -479,6 +623,11 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += (
             f"-{params.decoding_method}-beam-size-{params.beam_size}"

@@ -592,10 +741,24 @@ def main():
     model.to(device)
     model.eval()

-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
+        word_table = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -617,6 +780,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            word_table=word_table,
             decoding_graph=decoding_graph,
         )

View File

@@ -44,16 +44,53 @@ Usage:
     --decoding-method modified_beam_search \
     --beam-size 4

-(4) fast beam search
+(4) fast beam search (one best)
 ./pruned_transducer_stateless5/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless5/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
-    --beam 4 \
-    --max-contexts 4 \
-    --max-states 8
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
 """
@@ -70,6 +107,9 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,

@@ -83,6 +123,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,

@@ -128,7 +169,7 @@ def get_parser():
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@@ -150,6 +191,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,

@@ -159,6 +207,11 @@ def get_parser():
         - beam_search
         - modified_beam_search
         - fast_beam_search
+        - fast_beam_search_nbest
+        - fast_beam_search_nbest_oracle
+        - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
         """,
     )

@@ -174,27 +227,42 @@ def get_parser():
     parser.add_argument(
         "--beam",
         type=float,
-        default=4,
+        default=20.0,
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
     )

+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
-        default=4,
+        default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
     )

     parser.add_argument(
         "--max-states",
         type=int,
-        default=8,
+        default=64,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
     )

     parser.add_argument(

@@ -212,6 +280,24 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -222,6 +308,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the

@@ -245,9 +332,12 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.

@@ -279,6 +369,49 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1

@@ -326,14 +459,17 @@ def decode_one_batch(
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
-    elif params.decoding_method == "fast_beam_search":
-        return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
-        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+        return {key: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -343,6 +479,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.

@@ -356,9 +493,12 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.

@@ -387,6 +527,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            word_table=word_table,
             batch=batch,
         )

@@ -468,6 +609,9 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method

@@ -481,6 +625,11 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += (
             f"-{params.decoding_method}-beam-size-{params.beam_size}"

@@ -594,10 +743,24 @@ def main():
     model.to(device)
     model.eval()

-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
+        word_table = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -619,6 +782,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            word_table=word_table,
             decoding_graph=decoding_graph,
         )

View File

@@ -308,9 +308,7 @@ class Nbest(object):
         del word_fsa.aux_labels
         word_fsa.scores.zero_()
-        word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
-            word_fsa
-        )
+        word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)

         path_to_utt_map = self.shape.row_ids(1)
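The paths handled here are linear FSAs, so the general-purpose k2.remove_epsilon_and_add_self_loops() call can be replaced by the more direct k2.linear_fsa_with_self_loops(), available in recent k2 versions. A tiny illustration with a made-up word-ID path:

    import k2

    word_fsa = k2.linear_fsa([[3, 7, 9]])  # one hypothetical word-ID path
    word_fsa_with_loops = k2.linear_fsa_with_self_loops(word_fsa)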
@@ -609,7 +607,7 @@ def rescore_with_n_best_list(
       num_paths:
         Size of nbest list.
       lm_scale_list:
-        A list of float representing LM score scales.
+        A list of floats representing LM score scales.
       nbest_scale:
         Scale to be applied to ``lattice.score`` when sampling paths
         using ``k2.random_paths``.
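In practice lm_scale_list lets one decoding run sweep several LM weights at once, with each scale producing its own result set. For example (the values are hypothetical):

    lm_scale_list = [0.1, 0.3, 0.5, 0.7, 0.9]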