Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-08-09 01:52:41 +00:00
Add RNN-LM rescoring in fast beam search (#475)
This commit is contained in:
  parent aec222e2fe
  commit 608473b4eb
@@ -24,7 +24,7 @@ import torch
 from model import Transducer

 from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.utils import add_eos, add_sos, get_texts


 def fast_beam_search_one_best(
@@ -46,7 +46,7 @@ def fast_beam_search_one_best(
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
@@ -106,7 +106,7 @@ def fast_beam_search_nbest_LG(
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
@@ -226,7 +226,7 @@ def fast_beam_search_nbest(
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
@@ -311,7 +311,7 @@ def fast_beam_search_nbest_oracle(
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
@@ -397,7 +397,7 @@ def fast_beam_search(
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
@@ -1219,13 +1219,15 @@ def fast_beam_search_with_nbest_rescoring(
     temperature: float = 1.0,
 ) -> Dict[str, List[List[int]]]:
     """It limits the maximum number of symbols per frame to 1.
-    A lattice is first obtained using modified beam search, and then
-    the shortest path within the lattice is used as the final output.
+    A lattice is first obtained using fast beam search, num_paths paths are
+    selected and rescored using a given language model. The shortest path
+    within the lattice is used as the final output.

     Args:
       model:
         An instance of `Transducer`.
       decoding_graph:
-        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder.
       encoder_out_lens:
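In plain terms, the rescoring described by this docstring adds a scaled n-gram LM score to the acoustic score of each selected path and keeps the best-scoring path. A minimal illustrative sketch, not part of the commit; `am_scores` and `lm_scores` here are hypothetical per-path totals:

import torch

def pick_best_path(am_scores: torch.Tensor,
                   lm_scores: torch.Tensor,
                   ngram_lm_scale: float) -> int:
    # Per-path total = acoustic score + scaled n-gram LM score; the index of
    # the highest total identifies the path used as the final output.
    tot_scores = am_scores + ngram_lm_scale * lm_scores
    return int(tot_scores.argmax())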
@@ -1350,3 +1352,190 @@ def fast_beam_search_with_nbest_rescoring(
         ans[key] = hyps

     return ans
+
+
+def fast_beam_search_with_nbest_rnn_rescoring(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    ngram_lm_scale_list: List[float],
+    num_paths: int,
+    G: k2.Fsa,
+    sp: spm.SentencePieceProcessor,
+    word_table: k2.SymbolTable,
+    rnn_lm_model: torch.nn.Module,
+    rnn_lm_scale_list: List[float],
+    oov_word: str = "<UNK>",
+    use_double_scores: bool = True,
+    nbest_scale: float = 0.5,
+    temperature: float = 1.0,
+) -> Dict[str, List[List[int]]]:
+    """It limits the maximum number of symbols per frame to 1.
+    A lattice is first obtained using fast beam search, num_paths paths are
+    selected and rescored using a given n-gram language model and an RNN-LM.
+    The shortest path within the lattice is used as the final output.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or a LG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in
+        `encoder_out` before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+      ngram_lm_scale_list:
+        A list of floats representing n-gram LM score scales.
+      num_paths:
+        Number of paths to extract from the decoded lattice.
+      G:
+        An FsaVec containing only a single FSA. It is an n-gram LM.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      rnn_lm_model:
+        An RNN-LM model used for LM rescoring.
+      rnn_lm_scale_list:
+        A list of floats representing RNN-LM score scales.
+      oov_word:
+        OOV words are replaced with this word.
+      use_double_scores:
+        True to use double precision for computation. False to use
+        single precision.
+      nbest_scale:
+        It's the scale applied to the lattice.scores. A smaller value
+        yields more unique paths.
+      temperature:
+        Softmax temperature.
+    Returns:
+      Return the decoded result in a dict, where the key has the form
+      'ngram_lm_scale_xx_rnn_lm_scale_yy' and the value is the decoded
+      results. `xx` and `yy` are the n-gram LM and RNN-LM scale values
+      used during decoding, e.g., 0.1.
+    """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+        temperature=temperature,
+    )
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # At this point, nbest.fsa.scores are all zeros.
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa.scores contains acoustic scores.
+
+    am_scores = nbest.tot_scores()
+
+    # Now we need to compute the LM scores of each path.
+    # (1) Get the token IDs of each path. We assume the decoding_graph
+    # is an acceptor, i.e., the lattice is also an acceptor.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)  # [path][arc]
+
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
+    tokens = tokens.remove_values_leq(0)  # remove -1 and 0
+
+    token_list: List[List[int]] = tokens.tolist()
+    word_list: List[List[str]] = sp.decode(token_list)
+
+    assert isinstance(oov_word, str), oov_word
+    assert oov_word in word_table, oov_word
+    oov_word_id = word_table[oov_word]
+
+    word_ids_list: List[List[int]] = []
+
+    for words in word_list:
+        this_word_ids = []
+        for w in words.split():
+            if w in word_table:
+                this_word_ids.append(word_table[w])
+            else:
+                this_word_ids.append(oov_word_id)
+        word_ids_list.append(this_word_ids)
+
+    word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
+    word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
+
+    num_unique_paths = len(word_ids_list)
+
+    b_to_a_map = torch.zeros(
+        num_unique_paths,
+        dtype=torch.int32,
+        device=lattice.device,
+    )
+
+    rescored_word_fsas = k2.intersect_device(
+        a_fsas=G,
+        b_fsas=word_fsas_with_self_loops,
+        b_to_a_map=b_to_a_map,
+        sorted_match_a=True,
+        ret_arc_maps=False,
+    )
+
+    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
+    rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
+    ngram_lm_scores = rescored_word_fsas.get_tot_scores(
+        use_double_scores=True,
+        log_semiring=False,
+    )
+
+    # Now the RNN-LM.
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("sos_id")
+    eos_id = sp.piece_to_id("eos_id")
+
+    sos_tokens = add_sos(tokens, sos_id)
+    tokens_eos = add_eos(tokens, eos_id)
+    sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
+    sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+
+    x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
+    y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
+
+    x_tokens = x_tokens.to(torch.int64)
+    y_tokens = y_tokens.to(torch.int64)
+    sentence_lengths = sentence_lengths.to(torch.int64)
+
+    rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
+    assert rnn_lm_nll.ndim == 2
+    assert rnn_lm_nll.shape[0] == len(token_list)
+    rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
+
+    ans: Dict[str, List[List[int]]] = {}
+    for n_scale in ngram_lm_scale_list:
+        for rnn_scale in rnn_lm_scale_list:
+            key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores
+                + rnn_scale * rnn_lm_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+            hyps = get_texts(best_path)
+
+            ans[key] = hyps
+
+    return ans
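The RNN-LM step above scores each n-best path as a sentence: a start-of-sentence token is prepended to the inputs, an end-of-sentence token is appended to the targets, both are padded with the blank ID, and the per-token negative log-likelihoods are summed. A small, self-contained sketch of that batching recipe, for illustration only; `paths`, `sos_id`, `eos_id`, and `blank_id` are assumed inputs rather than objects from the commit:

import torch
from torch.nn.utils.rnn import pad_sequence

def make_rnn_lm_batch(paths, sos_id, eos_id, blank_id):
    # paths: List[List[int]] of token IDs, one list per n-best path.
    xs = [torch.tensor([sos_id] + p, dtype=torch.int64) for p in paths]
    ys = [torch.tensor(p + [eos_id], dtype=torch.int64) for p in paths]
    lengths = torch.tensor([x.numel() for x in xs], dtype=torch.int64)
    x = pad_sequence(xs, batch_first=True, padding_value=blank_id)
    y = pad_sequence(ys, batch_first=True, padding_value=blank_id)
    # An RNN-LM with the same (x, y, lengths) interface returns per-token NLL;
    # negating the row sums gives the per-path scores combined above.
    return x, y, lengths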
@@ -112,6 +112,7 @@ from beam_search import (
     fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     fast_beam_search_with_nbest_rescoring,
+    fast_beam_search_with_nbest_rnn_rescoring,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -125,8 +126,10 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
+    load_averaged_model,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -342,6 +345,62 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when --method is rnn-lm.
+        It specifies the path to the RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when --method is rnn-lm.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when --method is rnn-lm.
+        It specifies the number of checkpoints to average.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers in the model",
+    )
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=True,
+        help="""True to share the weights between the input embedding layer
+        and the last output linear layer.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
@@ -355,6 +414,7 @@ def decode_one_batch(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
+    rnn_lm_model: torch.nn.Module = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -526,6 +586,30 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
             temperature=params.temperature,
         )
+    elif params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring":
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
+        hyp_tokens = fast_beam_search_with_nbest_rnn_rescoring(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+            ngram_lm_scale_list=ngram_lm_scale_list,
+            num_paths=params.num_paths,
+            G=G,
+            sp=sp,
+            word_table=word_table,
+            rnn_lm_model=rnn_lm_model,
+            rnn_lm_scale_list=ngram_lm_scale_list,
+            use_double_scores=True,
+            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
+        )
     else:
         batch_size = encoder_out.size(0)

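Note that the branch above passes the same list for both ngram_lm_scale_list and rnn_lm_scale_list, so rescoring sweeps the full Cartesian grid of scale pairs. A tiny illustrative check (not part of the diff) of how many result sets that produces per batch:

# Reproduces the scale grid defined in the branch above.
ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
pairs = [(n, r) for n in ngram_lm_scale_list for r in ngram_lm_scale_list]
print(len(pairs))  # 17 * 17 = 289 keys in the returned dict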
@@ -571,7 +655,10 @@ def decode_one_batch(
                 f"temperature_{params.temperature}"
             ): hyps
         }
-    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+    elif params.decoding_method in [
+        "fast_beam_search_with_nbest_rescoring",
+        "fast_beam_search_with_nbest_rnn_rescoring",
+    ]:
         prefix = (
             f"beam_{params.beam}_"
             f"max_contexts_{params.max_contexts}_"
@@ -612,6 +699,7 @@ def decode_dataset(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
+    rnn_lm_model: torch.nn.Module = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.

@@ -666,6 +754,7 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             batch=batch,
             G=G,
+            rnn_lm_model=rnn_lm_model,
         )

         for name, hyps in hyps_dict.items():
@@ -816,6 +905,7 @@ def main():
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
         "fast_beam_search_with_nbest_rescoring",
+        "fast_beam_search_with_nbest_rnn_rescoring",
     )
     params.res_dir = params.exp_dir / params.decoding_method

@@ -919,7 +1009,10 @@ def main():
             torch.load(lg_filename, map_location=device)
         )
         decoding_graph.scores *= params.ngram_lm_scale
-    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+    elif params.decoding_method in [
+        "fast_beam_search_with_nbest_rescoring",
+        "fast_beam_search_with_nbest_rnn_rescoring",
+    ]:
         logging.info(f"Loading word symbol table from {params.words_txt}")
         word_table = k2.SymbolTable.from_file(params.words_txt)

@@ -932,14 +1025,43 @@ def main():
                 params.vocab_size - 1, device=device
             )
             logging.info(f"G properties_str: {G.properties_str}")
+            rnn_lm_model = None
+            if (
+                params.decoding_method
+                == "fast_beam_search_with_nbest_rnn_rescoring"
+            ):
+                rnn_lm_model = RnnLmModel(
+                    vocab_size=params.vocab_size,
+                    embedding_dim=params.rnn_lm_embedding_dim,
+                    hidden_dim=params.rnn_lm_hidden_dim,
+                    num_layers=params.rnn_lm_num_layers,
+                    tie_weights=params.rnn_lm_tie_weights,
+                )
+                if params.rnn_lm_avg == 1:
+                    load_checkpoint(
+                        f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+                        rnn_lm_model,
+                    )
+                    rnn_lm_model.to(device)
+                else:
+                    rnn_lm_model = load_averaged_model(
+                        params.rnn_lm_exp_dir,
+                        rnn_lm_model,
+                        params.rnn_lm_epoch,
+                        params.rnn_lm_avg,
+                        device,
+                    )
+                rnn_lm_model.eval()
         else:
             word_table = None
             decoding_graph = k2.trivial_graph(
                 params.vocab_size - 1, device=device
             )
+            rnn_lm_model = None
     else:
         decoding_graph = None
         word_table = None
+        rnn_lm_model = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
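When --rnn-lm-avg is greater than 1, load_averaged_model is used above to average several epoch checkpoints before evaluation. As a rough, hypothetical sketch of what parameter averaging over saved state dicts amounts to (this is illustrative, not icefall's actual implementation; the checkpoint layout with a "model" key is an assumption):

import torch

def average_state_dicts(ckpt_paths):
    # Element-wise average of the parameter tensors of several checkpoints.
    avg = None
    for path in ckpt_paths:
        sd = torch.load(path, map_location="cpu")["model"]  # assumed layout
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    return {k: v / len(ckpt_paths) for k, v in avg.items()}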
@@ -965,6 +1087,7 @@ def main():
         word_table=word_table,
         decoding_graph=decoding_graph,
         G=G,
+        rnn_lm_model=rnn_lm_model,
     )

     save_results(
@@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
         An FsaVec with axes [utt][state][arc].
       num_paths:
         Number of paths to extract from the given lattice for rescoring.
+      rnn_lm_model:
+        An RNN-LM model used for LM rescoring.
       model:
         A transformer model. See the class "Transformer" in
         conformer_ctc/transformer.py for its interface.