mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Add fast_beam_search_nbest.
This commit is contained in:
parent
53f38c01d2
commit
1bf2e17437
@ -75,6 +75,86 @@ def fast_beam_search_one_best(
|
|||||||
return hyps
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_nbest(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
num_paths: int,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
The process to get the results is:
|
||||||
|
- (1) Use fast beam search to get a lattice
|
||||||
|
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
|
||||||
|
- (3) Unique the selected paths
|
||||||
|
- (4) Intersect the selected paths with the lattice and compute the
|
||||||
|
shortest path from the intersection result
|
||||||
|
- (5) The path with the largest score is used as the decoding output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi..
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# at this point, nbest.fsa.scores are all zeros.
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa.scores contains acoustic scores
|
||||||
|
|
||||||
|
max_indexes = nbest.tot_scores().argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_nbest_oracle(
|
def fast_beam_search_nbest_oracle(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: k2.Fsa,
|
||||||
|
@ -82,6 +82,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
@ -250,6 +251,26 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search and
|
||||||
|
--use-LG is True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search and
|
||||||
|
--use-LG is True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -307,21 +328,32 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
if not params.use_LG:
|
||||||
model=model,
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
decoding_graph=decoding_graph,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out=encoder_out,
|
||||||
beam=params.beam,
|
encoder_out_lens=encoder_out_lens,
|
||||||
max_contexts=params.max_contexts,
|
beam=params.beam,
|
||||||
max_states=params.max_states,
|
max_contexts=params.max_contexts,
|
||||||
)
|
max_states=params.max_states,
|
||||||
if params.use_LG:
|
)
|
||||||
for hyp in hyp_tokens:
|
|
||||||
hyps.append([word_table[i] for i in hyp])
|
|
||||||
else:
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
|
@ -37,7 +37,7 @@ def fast_beam_search_one_best(
|
|||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
A lattice is first obtained using modified beam search, and then
|
A lattice is first obtained using fast beam search, and then
|
||||||
the shortest path within the lattice is used as the final output.
|
the shortest path within the lattice is used as the final output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -74,6 +74,86 @@ def fast_beam_search_one_best(
|
|||||||
return hyps
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_nbest(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
num_paths: int,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
The process to get the results is:
|
||||||
|
- (1) Use fast beam search to get a lattice
|
||||||
|
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
|
||||||
|
- (3) Unique the selected paths
|
||||||
|
- (4) Intersect the selected paths with the lattice and compute the
|
||||||
|
shortest path from the intersection result
|
||||||
|
- (5) The path with the largest score is used as the decoding output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi..
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# at this point, nbest.fsa.scores are all zeros.
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa.scores contains acoustic scores
|
||||||
|
|
||||||
|
max_indexes = nbest.tot_scores().argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_nbest_oracle(
|
def fast_beam_search_nbest_oracle(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: k2.Fsa,
|
||||||
@ -89,7 +169,7 @@ def fast_beam_search_nbest_oracle(
|
|||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
A lattice is first obtained using modified beam search, and then
|
A lattice is first obtained using fast beam search, and then
|
||||||
we select `num_paths` linear paths from the lattice. The path
|
we select `num_paths` linear paths from the lattice. The path
|
||||||
that has the minimum edit distance with the given reference transcript
|
that has the minimum edit distance with the given reference transcript
|
||||||
is used as the output.
|
is used as the output.
|
||||||
|
@ -43,7 +43,7 @@ Usage:
|
|||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
@ -53,6 +53,32 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -69,6 +95,8 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
@ -145,6 +173,8 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -164,7 +194,9 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -172,7 +204,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -180,7 +213,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -198,6 +232,26 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -231,7 +285,8 @@ def decode_one_batch(
|
|||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -263,6 +318,35 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -318,6 +402,16 @@ def decode_one_batch(
|
|||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -342,7 +436,8 @@ def decode_dataset(
|
|||||||
The BPE model.
|
The BPE model.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -360,7 +455,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 10
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -452,6 +547,8 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -461,10 +558,16 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if params.decoding_method == "fast_beam_search":
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += (
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
@ -528,7 +631,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if "fast_beam_search" in params.decoding_method:
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
@ -19,40 +19,66 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -69,6 +95,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle,
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
@ -147,6 +174,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -168,7 +196,8 @@ def get_parser():
|
|||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is
|
Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -176,7 +205,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -184,7 +214,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -205,9 +236,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-paths",
|
"--num-paths",
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=200,
|
||||||
help="""Number of paths for computed nbest oracle WER
|
help="""Number of paths for nbest decoding.
|
||||||
when the decoding method is fast_beam_search_nbest_oracle.
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,9 +248,11 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help="""Scale applied to lattice scores when computing nbest paths.
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
Used only when the decoding_method is fast_beam_search_nbest_oracle.
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -252,8 +286,8 @@ def decode_one_batch(
|
|||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is
|
only when --decoding_method is fast_beam_search,
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle.
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -285,6 +319,20 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
model=model,
|
model=model,
|
||||||
@ -355,7 +403,7 @@ def decode_one_batch(
|
|||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
return {
|
return {
|
||||||
(
|
(
|
||||||
f"beam_{params.beam}_"
|
f"beam_{params.beam}_"
|
||||||
@ -389,7 +437,8 @@ def decode_dataset(
|
|||||||
The BPE model.
|
The BPE model.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -407,7 +456,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 10
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -499,6 +548,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
@ -513,7 +563,7 @@ def main():
|
|||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
@ -539,9 +589,9 @@ def main():
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.unk_id()
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
@ -583,10 +633,7 @@ def main():
|
|||||||
model.device = device
|
model.device = device
|
||||||
model.unk_id = params.unk_id
|
model.unk_id = params.unk_id
|
||||||
|
|
||||||
if params.decoding_method in (
|
if "fast_beam_search" in params.decoding_method:
|
||||||
"fast_beam_search",
|
|
||||||
"fast_beam_search_nbest_oracle",
|
|
||||||
):
|
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
@ -44,7 +44,7 @@ Usage:
|
|||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
@ -54,6 +54,32 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless4/decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless4/decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +96,8 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
@ -159,6 +187,8 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -178,7 +208,9 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -186,7 +218,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -194,7 +227,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -212,6 +246,26 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -277,6 +331,35 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -332,6 +415,16 @@ def decode_one_batch(
|
|||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -356,7 +449,8 @@ def decode_dataset(
|
|||||||
The BPE model.
|
The BPE model.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -374,7 +468,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 10
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -466,6 +560,8 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -475,10 +571,16 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if params.decoding_method == "fast_beam_search":
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += (
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
@ -592,7 +694,7 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if "fast_beam_search" in params.decoding_method:
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
@ -44,7 +44,7 @@ Usage:
|
|||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless5/decode.py \
|
./pruned_transducer_stateless5/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
@ -54,6 +54,32 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless5/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless5/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +96,8 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
@ -128,7 +156,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-averaged-model",
|
"--use-averaged-model",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=True,
|
||||||
help="Whether to load averaged model. Currently it only supports "
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
"using --epoch. If True, it would decode with the averaged model "
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
@ -159,6 +187,8 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -178,7 +208,9 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -186,7 +218,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -194,7 +227,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search""",
|
fast_beam_search, fast_beam_search_nbest, or
|
||||||
|
fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -212,6 +246,26 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest or
|
||||||
|
fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -279,6 +333,35 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -334,6 +417,16 @@ def decode_one_batch(
|
|||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -358,7 +451,8 @@ def decode_dataset(
|
|||||||
The BPE model.
|
The BPE model.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -468,6 +562,8 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -477,10 +573,16 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if params.decoding_method == "fast_beam_search":
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "fast_beam_search_nbest" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += (
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
@ -594,7 +696,7 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if "fast_beam_search" in params.decoding_method:
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
@ -308,9 +308,7 @@ class Nbest(object):
|
|||||||
del word_fsa.aux_labels
|
del word_fsa.aux_labels
|
||||||
|
|
||||||
word_fsa.scores.zero_()
|
word_fsa.scores.zero_()
|
||||||
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
|
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
||||||
word_fsa
|
|
||||||
)
|
|
||||||
|
|
||||||
path_to_utt_map = self.shape.row_ids(1)
|
path_to_utt_map = self.shape.row_ids(1)
|
||||||
|
|
||||||
@ -609,7 +607,7 @@ def rescore_with_n_best_list(
|
|||||||
num_paths:
|
num_paths:
|
||||||
Size of nbest list.
|
Size of nbest list.
|
||||||
lm_scale_list:
|
lm_scale_list:
|
||||||
A list of float representing LM score scales.
|
A list of floats representing LM score scales.
|
||||||
nbest_scale:
|
nbest_scale:
|
||||||
Scale to be applied to ``lattice.score`` when sampling paths
|
Scale to be applied to ``lattice.score`` when sampling paths
|
||||||
using ``k2.random_paths``.
|
using ``k2.random_paths``.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user