k2-fsa/icefall (https://github.com/k2-fsa/icefall.git)

Commit b1c3705fbe (parent e9f0975868)

Compute the Nbest oracle WER for RNN-T decoding.
@@ -22,11 +22,11 @@ import k2
 import torch
 from model import Transducer
 
-from icefall.decode import one_best_decoding
+from icefall.decode import Nbest, one_best_decoding
 from icefall.utils import get_texts
 
 
-def fast_beam_search(
+def fast_beam_search_one_best(
     model: Transducer,
     decoding_graph: k2.Fsa,
     encoder_out: torch.Tensor,
@@ -37,6 +37,9 @@ def fast_beam_search(
 ) -> List[List[int]]:
     """It limits the maximum number of symbols per frame to 1.
 
+    A lattice is first obtained using modified beam search, and then
+    the shortest path within the lattice is used as the final output.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -56,6 +59,153 @@ def fast_beam_search(
     Returns:
       Return the decoded result.
     """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+
+    best_path = one_best_decoding(lattice)
+    hyps = get_texts(best_path)
+    return hyps
+
+
+def fast_beam_search_nbest_oracle(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    num_paths: int,
+    ref_texts: List[List[int]],
+    use_double_scores: bool = True,
+    nbest_scale: float = 0.5,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    A lattice is first obtained using modified beam search, and then
+    we select `num_paths` linear paths from the lattice. The path
+    that has the minimum edit distance with the given reference transcript
+    is used as the output.
+
+    This is the best result we can achieve for any nbest-based rescoring
+    method.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+      num_paths:
+        Number of paths to extract from the decoded lattice.
+      ref_texts:
+        A list-of-list of integers containing the reference transcripts.
+        If the decoding_graph is a trivial_graph, the integer ID is the
+        BPE token ID.
+      use_double_scores:
+        True to use double precision for computation. False to use
+        single precision.
+      nbest_scale:
+        It's the scale applied to the lattice.scores. A smaller value
+        yields more unique paths.
+
+    Returns:
+      Return the decoded result.
+    """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+
+    # We assume the labels of nbest.fsa are token IDs and the aux_labels
+    # are word IDs.
+    word_fsa = k2.invert(nbest.fsa)
+    word_ids = get_texts(word_fsa, return_ragged=True)
+
+    hyps = k2.levenshtein_graph(word_ids)
+    refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
+
+    levenshtein_alignment = k2.levenshtein_alignment(
+        refs=refs,
+        hyps=hyps,
+        hyp_to_ref_map=nbest.shape.row_ids(1),
+        sorted_match_ref=True,
+    )
+
+    tot_scores = levenshtein_alignment.get_tot_scores(
+        use_double_scores=False, log_semiring=False
+    )
+    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+
+    max_indexes = ragged_tot_scores.argmax()
+
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+    hyps = get_texts(best_path)
+    return hyps
+
+
+def fast_beam_search(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> k2.Fsa:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+
+    Returns:
+      Return an FsaVec with axes [utt][state][arc] containing the decoded
+      lattice. Note: When the input graph is a TrivialGraph, the returned
+      lattice is actually an acceptor.
+    """
     assert encoder_out.ndim == 3
 
     context_size = model.decoder.context_size
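
The oracle selection above is done on the GPU with k2's levenshtein_graph
and levenshtein_alignment, which score all nbest paths against their
references in parallel. As a minimal pure-Python sketch of the same idea
(`nbest_paths` and `ref` are hypothetical stand-ins for the extracted
token sequences and one reference transcript):

    from typing import List

    def edit_distance(a: List[int], b: List[int]) -> int:
        """Levenshtein distance via the standard two-row DP."""
        prev = list(range(len(b) + 1))
        for i, x in enumerate(a, 1):
            cur = [i]
            for j, y in enumerate(b, 1):
                cur.append(
                    min(
                        prev[j] + 1,  # deletion
                        cur[j - 1] + 1,  # insertion
                        prev[j - 1] + (x != y),  # substitution
                    )
                )
            prev = cur
        return prev[-1]

    def oracle_path(nbest_paths: List[List[int]], ref: List[int]) -> List[int]:
        """Return the candidate closest to the reference in edit distance."""
        return min(nbest_paths, key=lambda path: edit_distance(path, ref))
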
@@ -104,9 +254,7 @@ def fast_beam_search(
     decoding_streams.terminate_and_flush_to_streams()
     lattice = decoding_streams.format_output(encoder_out_lens.tolist())
 
-    best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
-    return hyps
+    return lattice
 
 
 def greedy_search(
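
With this refactor, fast_beam_search is the shared lattice generator and the
two public wrappers differ only in how they post-process the lattice. A
sketch of the one-best side, assuming (as in icefall) that one_best_decoding
is a thin wrapper around k2.shortest_path:

    import k2

    def one_best_sketch(lattice: k2.Fsa) -> k2.Fsa:
        # Keep only the highest-scoring linear path of each utterance's
        # lattice; get_texts() then reads the token IDs off that path.
        return k2.shortest_path(lattice, use_double_scores=True)
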
@@ -131,6 +279,7 @@ def greedy_search(
 
     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
+    unk_id = getattr(model, "unk_id", blank_id)
 
     device = model.device
 
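
The getattr fallback (repeated in the hunks below) keeps older models without
an unk_id attribute working unchanged: when the attribute is missing, unk_id
aliases blank_id and the new membership test degenerates to the old
blank-only check. A small sketch:

    class ModelWithoutUnk:  # hypothetical model lacking unk_id
        pass

    blank_id = 0
    unk_id = getattr(ModelWithoutUnk(), "unk_id", blank_id)
    assert unk_id == blank_id
    # identical to the old `y != blank_id` test in this case
    assert (3 not in (blank_id, unk_id)) == (3 != blank_id)
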
@@ -171,7 +320,7 @@ def greedy_search(
         # logits is (1, 1, 1, vocab_size)
 
         y = logits.argmax().item()
-        if y != blank_id:
+        if y not in (blank_id, unk_id):
             hyp.append(y)
             decoder_input = torch.tensor(
                 [hyp[-context_size:]], device=device
@@ -212,6 +361,7 @@ def greedy_search_batch(
     T = encoder_out.size(1)
 
     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
     hyps = [[blank_id] * context_size for _ in range(batch_size)]
@@ -240,7 +390,7 @@ def greedy_search_batch(
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
-            if v != blank_id:
+            if v not in (blank_id, unk_id):
                 hyps[i].append(v)
                 emitted = True
         if emitted:
@@ -433,6 +583,7 @@ def modified_beam_search(
     T = encoder_out.size(1)
 
     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
     device = model.device
     B = [HypothesisList() for _ in range(batch_size)]
@@ -515,7 +666,7 @@ def modified_beam_search(
 
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
-                if new_token != blank_id:
+                if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
 
                 new_log_prob = topk_log_probs[k]
@@ -556,6 +707,7 @@ def _deprecated_modified_beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
     device = model.device
@@ -626,7 +778,7 @@ def _deprecated_modified_beam_search(
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
             new_token = topk_token_indexes[i]
-            if new_token != blank_id:
+            if new_token not in (blank_id, unk_id):
                 new_ys.append(new_token)
             new_log_prob = topk_log_probs[i]
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@@ -663,6 +815,7 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
     device = model.device
@@ -748,7 +901,7 @@ def beam_search(
         # Second, process other non-blank labels
         values, indices = log_prob.topk(beam + 1)
         for i, v in zip(indices.tolist(), values.tolist()):
-            if i == blank_id:
+            if i in (blank_id, unk_id):
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v
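
One side effect worth noting: topk(beam + 1) was sized so that dropping
blank still leaves `beam` candidates, but with unk also skipped, up to two
of the beam + 1 labels can be discarded. A sketch with hypothetical values:

    import torch

    beam, blank_id, unk_id = 4, 0, 2
    log_prob = torch.randn(10).log_softmax(dim=0)
    values, indices = log_prob.topk(beam + 1)
    kept = [
        (i, v)
        for i, v in zip(indices.tolist(), values.tolist())
        if i not in (blank_id, unk_id)
    ]
    assert len(kept) >= beam - 1  # not necessarily >= beam
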
@@ -69,7 +69,8 @@ import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -145,6 +146,7 @@ def get_parser():
           - beam_search
           - modified_beam_search
           - fast_beam_search
+          - fast_beam_search_nbest_oracle
         """,
     )
 
@@ -164,7 +166,8 @@ def get_parser():
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
@@ -172,7 +175,7 @@ def get_parser():
         type=int,
         default=4,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
@@ -180,7 +183,7 @@ def get_parser():
         type=int,
         default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
@@ -198,6 +201,23 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths used to compute the nbest oracle WER
+        when the decoding method is fast_beam_search_nbest_oracle.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding_method is fast_beam_search_nbest_oracle.
+        """,
+    )
+
     return parser
 
 
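
A sketch of how the two new options behave in isolation (a stand-alone
parser is used here because the other arguments of get_parser() are not
shown in this diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-paths", type=int, default=100)
    parser.add_argument("--nbest-scale", type=float, default=0.5)

    args = parser.parse_args(["--num-paths", "200", "--nbest-scale", "0.8"])
    assert args.num_paths == 200 and args.nbest_scale == 0.8
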
@@ -231,7 +251,8 @@ def decode_one_batch(
         for the format of the `batch`.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is
+        fast_beam_search or fast_beam_search_nbest_oracle.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -252,7 +273,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -263,6 +284,21 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
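
The oracle branch passes the references as BPE token IDs because the nbest
paths are token-level. A sketch of what sp.encode produces, assuming a
trained sentencepiece model at a hypothetical path bpe.model:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("bpe.model")
    # One List[int] of token IDs per reference string.
    ref_texts = sp.encode(["HELLO WORLD", "GOOD MORNING"])
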
@@ -316,6 +352,16 @@ def decode_one_batch(
                 f"max_states_{params.max_states}"
             ): hyps
         }
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}_"
+                f"num_paths_{params.num_paths}_"
+                f"nbest_scale_{params.nbest_scale}"
+            ): hyps
+        }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
 
@@ -450,15 +496,22 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "fast_beam_search" in params.decoding_method:
+    if params.decoding_method == "fast_beam_search":
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-num-paths-{params.num_paths}"
+        params.suffix += f"-nbest-scale-{params.nbest_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
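
For hypothetical parameter values, the nbest-oracle branch above would tag
result files with a suffix built like this:

    # Hypothetical values, purely for illustration of the naming scheme.
    epoch, avg, beam, max_contexts, max_states = 30, 13, 4.0, 4, 8
    num_paths, nbest_scale = 100, 0.5
    suffix = f"epoch-{epoch}-avg-{avg}"
    suffix += f"-beam-{beam}-max-contexts-{max_contexts}-max-states-{max_states}"
    suffix += f"-num-paths-{num_paths}-nbest-scale-{nbest_scale}"
    # -> "epoch-30-avg-13-beam-4.0-max-contexts-4-max-states-8-num-paths-100-nbest-scale-0.5"
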
@@ -479,6 +532,7 @@ def main():
 
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.unk_id()
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -506,8 +560,12 @@ def main():
     model.to(device)
     model.eval()
     model.device = device
+    model.unk_id = params.unk_id
 
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method in (
+        "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
+    ):
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
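
Both fast-beam-search variants need a decoding graph. k2.trivial_graph
builds a small acceptor with a self-loop for every token ID from 1 to
max_token, so decoding is constrained only by the transducer scores. A
sketch with a hypothetical vocabulary size:

    import k2
    import torch

    vocab_size = 500  # assumption for illustration
    device = torch.device("cpu")
    # vocab_size - 1 is the largest BPE token ID when <blk> has ID 0.
    decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)
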