mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
remove unused decoding methods
This commit is contained in:
parent
92f6128127
commit
d6b88aaa98
@ -20,290 +20,11 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
from model import SURT
|
from model import SURT
|
||||||
|
|
||||||
from icefall import NgramLm, NgramLmStateCost
|
from icefall import NgramLmStateCost
|
||||||
from icefall.decode import Nbest, one_best_decoding
|
from icefall.utils import DecodingResults
|
||||||
from icefall.lm_wrapper import LmScorer
|
|
||||||
from icefall.utils import (
|
|
||||||
DecodingResults,
|
|
||||||
add_eos,
|
|
||||||
add_sos,
|
|
||||||
get_texts,
|
|
||||||
get_texts_with_timestamp,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_one_best(
|
|
||||||
model: SURT,
|
|
||||||
decoding_graph: k2.Fsa,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
encoder_out_lens: torch.Tensor,
|
|
||||||
beam: float,
|
|
||||||
max_states: int,
|
|
||||||
max_contexts: int,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
return_timestamps: bool = False,
|
|
||||||
) -> Union[List[List[int]], DecodingResults]:
|
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
|
||||||
|
|
||||||
A lattice is first obtained using fast beam search, and then
|
|
||||||
the shortest path within the lattice is used as the final output.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model:
|
|
||||||
An instance of `SURT`.
|
|
||||||
decoding_graph:
|
|
||||||
Decoding graph used for decoding, may be a TrivialGraph or a LG.
|
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder.
|
|
||||||
encoder_out_lens:
|
|
||||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
|
||||||
before padding.
|
|
||||||
beam:
|
|
||||||
Beam value, similar to the beam used in Kaldi..
|
|
||||||
max_states:
|
|
||||||
Max states per stream per frame.
|
|
||||||
max_contexts:
|
|
||||||
Max contexts pre stream per frame.
|
|
||||||
temperature:
|
|
||||||
Softmax temperature.
|
|
||||||
return_timestamps:
|
|
||||||
Whether to return timestamps.
|
|
||||||
Returns:
|
|
||||||
If return_timestamps is False, return the decoded result.
|
|
||||||
Else, return a DecodingResults object containing
|
|
||||||
decoded result and corresponding timestamps.
|
|
||||||
"""
|
|
||||||
lattice = fast_beam_search(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=beam,
|
|
||||||
max_states=max_states,
|
|
||||||
max_contexts=max_contexts,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
best_path = one_best_decoding(lattice)
|
|
||||||
|
|
||||||
if not return_timestamps:
|
|
||||||
return get_texts(best_path)
|
|
||||||
else:
|
|
||||||
return get_texts_with_timestamp(best_path)
|
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_nbest_LG(
|
|
||||||
model: SURT,
|
|
||||||
decoding_graph: k2.Fsa,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
encoder_out_lens: torch.Tensor,
|
|
||||||
beam: float,
|
|
||||||
max_states: int,
|
|
||||||
max_contexts: int,
|
|
||||||
num_paths: int,
|
|
||||||
nbest_scale: float = 0.5,
|
|
||||||
use_double_scores: bool = True,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
return_timestamps: bool = False,
|
|
||||||
) -> Union[List[List[int]], DecodingResults]:
|
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
|
||||||
The process to get the results is:
|
|
||||||
- (1) Use fast beam search to get a lattice
|
|
||||||
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
|
|
||||||
- (3) Unique the selected paths
|
|
||||||
- (4) Intersect the selected paths with the lattice and compute the
|
|
||||||
shortest path from the intersection result
|
|
||||||
- (5) The path with the largest score is used as the decoding output.
|
|
||||||
Args:
|
|
||||||
model:
|
|
||||||
An instance of `SURT`.
|
|
||||||
decoding_graph:
|
|
||||||
Decoding graph used for decoding, may be a TrivialGraph or a LG.
|
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder.
|
|
||||||
encoder_out_lens:
|
|
||||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
|
||||||
before padding.
|
|
||||||
beam:
|
|
||||||
Beam value, similar to the beam used in Kaldi..
|
|
||||||
max_states:
|
|
||||||
Max states per stream per frame.
|
|
||||||
max_contexts:
|
|
||||||
Max contexts pre stream per frame.
|
|
||||||
num_paths:
|
|
||||||
Number of paths to extract from the decoded lattice.
|
|
||||||
nbest_scale:
|
|
||||||
It's the scale applied to the lattice.scores. A smaller value
|
|
||||||
yields more unique paths.
|
|
||||||
use_double_scores:
|
|
||||||
True to use double precision for computation. False to use
|
|
||||||
single precision.
|
|
||||||
temperature:
|
|
||||||
Softmax temperature.
|
|
||||||
return_timestamps:
|
|
||||||
Whether to return timestamps.
|
|
||||||
Returns:
|
|
||||||
If return_timestamps is False, return the decoded result.
|
|
||||||
Else, return a DecodingResults object containing
|
|
||||||
decoded result and corresponding timestamps.
|
|
||||||
"""
|
|
||||||
lattice = fast_beam_search(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=beam,
|
|
||||||
max_states=max_states,
|
|
||||||
max_contexts=max_contexts,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
|
||||||
lattice=lattice,
|
|
||||||
num_paths=num_paths,
|
|
||||||
use_double_scores=use_double_scores,
|
|
||||||
nbest_scale=nbest_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The following code is modified from nbest.intersect()
|
|
||||||
word_fsa = k2.invert(nbest.fsa)
|
|
||||||
if hasattr(lattice, "aux_labels"):
|
|
||||||
# delete token IDs as it is not needed
|
|
||||||
del word_fsa.aux_labels
|
|
||||||
word_fsa.scores.zero_()
|
|
||||||
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
|
||||||
path_to_utt_map = nbest.shape.row_ids(1)
|
|
||||||
|
|
||||||
if hasattr(lattice, "aux_labels"):
|
|
||||||
# lattice has token IDs as labels and word IDs as aux_labels.
|
|
||||||
# inv_lattice has word IDs as labels and token IDs as aux_labels
|
|
||||||
inv_lattice = k2.invert(lattice)
|
|
||||||
inv_lattice = k2.arc_sort(inv_lattice)
|
|
||||||
else:
|
|
||||||
inv_lattice = k2.arc_sort(lattice)
|
|
||||||
|
|
||||||
if inv_lattice.shape[0] == 1:
|
|
||||||
path_lattice = k2.intersect_device(
|
|
||||||
inv_lattice,
|
|
||||||
word_fsa_with_epsilon_loops,
|
|
||||||
b_to_a_map=torch.zeros_like(path_to_utt_map),
|
|
||||||
sorted_match_a=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
path_lattice = k2.intersect_device(
|
|
||||||
inv_lattice,
|
|
||||||
word_fsa_with_epsilon_loops,
|
|
||||||
b_to_a_map=path_to_utt_map,
|
|
||||||
sorted_match_a=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# path_lattice has word IDs as labels and token IDs as aux_labels
|
|
||||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
|
||||||
tot_scores = path_lattice.get_tot_scores(
|
|
||||||
use_double_scores=use_double_scores,
|
|
||||||
log_semiring=True, # Note: we always use True
|
|
||||||
)
|
|
||||||
# See https://github.com/k2-fsa/icefall/pull/420 for why
|
|
||||||
# we always use log_semiring=True
|
|
||||||
|
|
||||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
|
||||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
|
||||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
|
||||||
|
|
||||||
if not return_timestamps:
|
|
||||||
return get_texts(best_path)
|
|
||||||
else:
|
|
||||||
return get_texts_with_timestamp(best_path)
|
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search(
|
|
||||||
model: SURT,
|
|
||||||
decoding_graph: k2.Fsa,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
encoder_out_lens: torch.Tensor,
|
|
||||||
beam: float,
|
|
||||||
max_states: int,
|
|
||||||
max_contexts: int,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
) -> k2.Fsa:
|
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model:
|
|
||||||
An instance of `SURT`.
|
|
||||||
decoding_graph:
|
|
||||||
Decoding graph used for decoding, may be a TrivialGraph or a LG.
|
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder.
|
|
||||||
encoder_out_lens:
|
|
||||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
|
||||||
before padding.
|
|
||||||
beam:
|
|
||||||
Beam value, similar to the beam used in Kaldi..
|
|
||||||
max_states:
|
|
||||||
Max states per stream per frame.
|
|
||||||
max_contexts:
|
|
||||||
Max contexts pre stream per frame.
|
|
||||||
temperature:
|
|
||||||
Softmax temperature.
|
|
||||||
Returns:
|
|
||||||
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
|
||||||
lattice. Note: When the input graph is a TrivialGraph, the returned
|
|
||||||
lattice is actually an acceptor.
|
|
||||||
"""
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
|
|
||||||
context_size = model.decoder.context_size
|
|
||||||
vocab_size = model.decoder.vocab_size
|
|
||||||
|
|
||||||
B, T, C = encoder_out.shape
|
|
||||||
|
|
||||||
config = k2.RnntDecodingConfig(
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
decoder_history_len=context_size,
|
|
||||||
beam=beam,
|
|
||||||
max_contexts=max_contexts,
|
|
||||||
max_states=max_states,
|
|
||||||
)
|
|
||||||
individual_streams = []
|
|
||||||
for i in range(B):
|
|
||||||
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
|
||||||
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
|
||||||
|
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
|
||||||
|
|
||||||
for t in range(T):
|
|
||||||
# shape is a RaggedShape of shape (B, context)
|
|
||||||
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
|
||||||
shape, contexts = decoding_streams.get_contexts()
|
|
||||||
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
|
||||||
contexts = contexts.to(torch.int64)
|
|
||||||
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
|
||||||
decoder_out = model.decoder(contexts, need_pad=False)
|
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
|
||||||
# current_encoder_out is of shape
|
|
||||||
# (shape.NumElements(), 1, joiner_dim)
|
|
||||||
# fmt: off
|
|
||||||
current_encoder_out = torch.index_select(
|
|
||||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
logits = model.joiner(
|
|
||||||
current_encoder_out.unsqueeze(2),
|
|
||||||
decoder_out.unsqueeze(1),
|
|
||||||
project_input=False,
|
|
||||||
)
|
|
||||||
logits = logits.squeeze(1).squeeze(1)
|
|
||||||
log_probs = (logits / temperature).log_softmax(dim=-1)
|
|
||||||
decoding_streams.advance(log_probs)
|
|
||||||
decoding_streams.terminate_and_flush_to_streams()
|
|
||||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
|
||||||
|
|
||||||
return lattice
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
@ -689,277 +410,6 @@ def modified_beam_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search_LODR(
|
|
||||||
model: SURT,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
encoder_out_lens: torch.Tensor,
|
|
||||||
LODR_lm: NgramLm,
|
|
||||||
LODR_lm_scale: float,
|
|
||||||
LM: LmScorer,
|
|
||||||
beam: int = 4,
|
|
||||||
) -> List[List[int]]:
|
|
||||||
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
|
|
||||||
`modified_beam_search`. It uses a bi-gram language model as the estimate
|
|
||||||
of the internal language model and subtracts its score during shallow fusion
|
|
||||||
with an external language model. This implementation uses a RNNLM as the
|
|
||||||
external language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (SURT):
|
|
||||||
The SURT model
|
|
||||||
encoder_out (torch.Tensor):
|
|
||||||
Encoder output in (N,T,C)
|
|
||||||
encoder_out_lens (torch.Tensor):
|
|
||||||
A 1-D tensor of shape (N,), containing the number of
|
|
||||||
valid frames in encoder_out before padding.
|
|
||||||
LODR_lm:
|
|
||||||
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
|
||||||
LODR_lm_scale:
|
|
||||||
The scale of the LODR_lm
|
|
||||||
LM:
|
|
||||||
A neural net LM, e.g an RNNLM or transformer LM
|
|
||||||
beam (int, optional):
|
|
||||||
Beam size. Defaults to 4.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
|
||||||
for the i-th utterance.
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
|
||||||
assert LM is not None
|
|
||||||
lm_scale = LM.lm_scale
|
|
||||||
|
|
||||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
|
||||||
input=encoder_out,
|
|
||||||
lengths=encoder_out_lens.cpu(),
|
|
||||||
batch_first=True,
|
|
||||||
enforce_sorted=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
|
||||||
sos_id = getattr(LM, "sos_id", 1)
|
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
|
||||||
context_size = model.decoder.context_size
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
|
|
||||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
|
||||||
N = encoder_out.size(0)
|
|
||||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
|
||||||
assert N == batch_size_list[0], (N, batch_size_list)
|
|
||||||
|
|
||||||
# get initial lm score and lm state by scoring the "sos" token
|
|
||||||
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
|
|
||||||
lens = torch.tensor([1]).to(device)
|
|
||||||
init_score, init_states = LM.score_token(sos_token, lens)
|
|
||||||
|
|
||||||
B = [HypothesisList() for _ in range(N)]
|
|
||||||
for i in range(N):
|
|
||||||
B[i].add(
|
|
||||||
Hypothesis(
|
|
||||||
ys=[blank_id] * context_size,
|
|
||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
|
||||||
state=init_states, # state of the NN LM
|
|
||||||
lm_score=init_score.reshape(-1),
|
|
||||||
state_cost=NgramLmStateCost(
|
|
||||||
LODR_lm
|
|
||||||
), # state of the source domain ngram
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
|
||||||
|
|
||||||
offset = 0
|
|
||||||
finalized_B = []
|
|
||||||
for batch_size in batch_size_list:
|
|
||||||
start = offset
|
|
||||||
end = offset + batch_size
|
|
||||||
current_encoder_out = encoder_out.data[start:end] # get batch
|
|
||||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
|
||||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
|
||||||
offset = end
|
|
||||||
|
|
||||||
finalized_B = B[batch_size:] + finalized_B
|
|
||||||
B = B[:batch_size]
|
|
||||||
|
|
||||||
hyps_shape = get_hyps_shape(B).to(device)
|
|
||||||
|
|
||||||
A = [list(b) for b in B]
|
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
|
||||||
|
|
||||||
ys_log_probs = torch.cat(
|
|
||||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
) # (num_hyps, context_size)
|
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
|
||||||
|
|
||||||
current_encoder_out = torch.index_select(
|
|
||||||
current_encoder_out,
|
|
||||||
dim=0,
|
|
||||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
|
||||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
|
||||||
|
|
||||||
logits = model.joiner(
|
|
||||||
current_encoder_out,
|
|
||||||
decoder_out,
|
|
||||||
project_input=False,
|
|
||||||
) # (num_hyps, 1, 1, vocab_size)
|
|
||||||
|
|
||||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
|
||||||
|
|
||||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
|
||||||
|
|
||||||
log_probs.add_(ys_log_probs)
|
|
||||||
|
|
||||||
vocab_size = log_probs.size(-1)
|
|
||||||
|
|
||||||
log_probs = log_probs.reshape(-1)
|
|
||||||
|
|
||||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
|
||||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
|
||||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
|
||||||
)
|
|
||||||
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
|
||||||
"""
|
|
||||||
for all hyps with a non-blank new token, score this token.
|
|
||||||
It is a little confusing here because this for-loop
|
|
||||||
looks very similar to the one below. Here, we go through all
|
|
||||||
top-k tokens and only add the non-blanks ones to the token_list.
|
|
||||||
LM will score those tokens given the LM states. Note that
|
|
||||||
the variable `scores` is the LM score after seeing the new
|
|
||||||
non-blank token.
|
|
||||||
"""
|
|
||||||
token_list = []
|
|
||||||
hs = []
|
|
||||||
cs = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
|
||||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
|
||||||
for k in range(len(topk_hyp_indexes)):
|
|
||||||
hyp_idx = topk_hyp_indexes[k]
|
|
||||||
hyp = A[i][hyp_idx]
|
|
||||||
|
|
||||||
new_token = topk_token_indexes[k]
|
|
||||||
if new_token not in (blank_id, unk_id):
|
|
||||||
if LM.lm_type == "rnn":
|
|
||||||
token_list.append([new_token])
|
|
||||||
# store the LSTM states
|
|
||||||
hs.append(hyp.state[0])
|
|
||||||
cs.append(hyp.state[1])
|
|
||||||
else:
|
|
||||||
# for transformer LM
|
|
||||||
token_list.append(
|
|
||||||
[sos_id] + hyp.ys[context_size:] + [new_token]
|
|
||||||
)
|
|
||||||
|
|
||||||
# forward NN LM to get new states and scores
|
|
||||||
if len(token_list) != 0:
|
|
||||||
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
|
|
||||||
if LM.lm_type == "rnn":
|
|
||||||
tokens_to_score = (
|
|
||||||
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
|
|
||||||
)
|
|
||||||
hs = torch.cat(hs, dim=1).to(device)
|
|
||||||
cs = torch.cat(cs, dim=1).to(device)
|
|
||||||
state = (hs, cs)
|
|
||||||
else:
|
|
||||||
# for transformer LM
|
|
||||||
tokens_list = [torch.tensor(tokens) for tokens in token_list]
|
|
||||||
tokens_to_score = (
|
|
||||||
torch.nn.utils.rnn.pad_sequence(
|
|
||||||
tokens_list, batch_first=True, padding_value=0.0
|
|
||||||
)
|
|
||||||
.to(device)
|
|
||||||
.to(torch.int64)
|
|
||||||
)
|
|
||||||
|
|
||||||
state = None
|
|
||||||
|
|
||||||
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
|
|
||||||
|
|
||||||
count = 0 # index, used to locate score and lm states
|
|
||||||
for i in range(batch_size):
|
|
||||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
|
||||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
|
||||||
|
|
||||||
for k in range(len(topk_hyp_indexes)):
|
|
||||||
hyp_idx = topk_hyp_indexes[k]
|
|
||||||
hyp = A[i][hyp_idx]
|
|
||||||
|
|
||||||
ys = hyp.ys[:]
|
|
||||||
|
|
||||||
# current score of hyp
|
|
||||||
lm_score = hyp.lm_score
|
|
||||||
state = hyp.state
|
|
||||||
|
|
||||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
|
||||||
new_token = topk_token_indexes[k]
|
|
||||||
if new_token not in (blank_id, unk_id):
|
|
||||||
|
|
||||||
ys.append(new_token)
|
|
||||||
state_cost = hyp.state_cost.forward_one_step(new_token)
|
|
||||||
|
|
||||||
# calculate the score of the latest token
|
|
||||||
current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score
|
|
||||||
|
|
||||||
assert current_ngram_score <= 0.0, (
|
|
||||||
state_cost.lm_score,
|
|
||||||
hyp.state_cost.lm_score,
|
|
||||||
)
|
|
||||||
# score = score + TDLM_score - LODR_score
|
|
||||||
# LODR_LM_scale should be a negative number here
|
|
||||||
hyp_log_prob += (
|
|
||||||
lm_score[new_token] * lm_scale
|
|
||||||
+ LODR_lm_scale * current_ngram_score
|
|
||||||
) # add the lm score
|
|
||||||
|
|
||||||
lm_score = scores[count]
|
|
||||||
if LM.lm_type == "rnn":
|
|
||||||
state = (
|
|
||||||
lm_states[0][:, count, :].unsqueeze(1),
|
|
||||||
lm_states[1][:, count, :].unsqueeze(1),
|
|
||||||
)
|
|
||||||
count += 1
|
|
||||||
else:
|
|
||||||
state_cost = hyp.state_cost
|
|
||||||
|
|
||||||
new_hyp = Hypothesis(
|
|
||||||
ys=ys,
|
|
||||||
log_prob=hyp_log_prob,
|
|
||||||
state=state,
|
|
||||||
lm_score=lm_score,
|
|
||||||
state_cost=state_cost,
|
|
||||||
)
|
|
||||||
B[i].add(new_hyp)
|
|
||||||
|
|
||||||
B = B + finalized_B
|
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
|
||||||
|
|
||||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
|
||||||
ans = []
|
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
|
||||||
for i in range(N):
|
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
model: SURT,
|
model: SURT,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
|||||||
@ -42,7 +42,6 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import chain, groupby, repeat
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@ -53,12 +52,9 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriCssAsrDataModule
|
from asr_datamodule import LibriCssAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_nbest_LG,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
modified_beam_search_LODR,
|
|
||||||
)
|
)
|
||||||
from lhotse.utils import EPSILON
|
from lhotse.utils import EPSILON
|
||||||
from train import add_model_arguments, get_params, get_surt_model
|
from train import add_model_arguments, get_params, get_surt_model
|
||||||
@ -155,9 +151,6 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -170,47 +163,6 @@ def get_parser():
|
|||||||
modified_beam_search.""",
|
modified_beam_search.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam",
|
|
||||||
type=float,
|
|
||||||
default=20.0,
|
|
||||||
help="""A floating point value to calculate the cutoff score during beam
|
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
|
||||||
`beam` in Kaldi.
|
|
||||||
Used only when --decoding-method is fast_beam_search,
|
|
||||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
|
||||||
and fast_beam_search_nbest_oracle
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ngram-lm-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.01,
|
|
||||||
help="""
|
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
|
||||||
It specifies the scale for n-gram LM scores.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-contexts",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="""Used only when --decoding-method is
|
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
|
||||||
and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-states",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="""Used only when --decoding-method is
|
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
|
||||||
and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -225,24 +177,6 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-paths",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="""Number of paths for nbest decoding.
|
|
||||||
Used only when the decoding method is fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--nbest-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="""Scale applied to lattice scores when computing nbest paths.
|
|
||||||
Used only when the decoding method is fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-masks",
|
"--save-masks",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -260,11 +194,6 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
|
||||||
ngram_lm: Optional[NgramLm] = None,
|
|
||||||
ngram_lm_scale: float = 1.0,
|
|
||||||
LM: Optional[LmScorer] = None,
|
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -287,12 +216,6 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
word_table:
|
|
||||||
The word symbol table.
|
|
||||||
decoding_graph:
|
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -348,33 +271,7 @@ def decode_one_batch(
|
|||||||
return out_hyps
|
return out_hyps
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam,
|
|
||||||
max_contexts=params.max_contexts,
|
|
||||||
max_states=params.max_states,
|
|
||||||
)
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp)
|
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam,
|
|
||||||
max_contexts=params.max_contexts,
|
|
||||||
max_states=params.max_states,
|
|
||||||
num_paths=params.num_paths,
|
|
||||||
nbest_scale=params.nbest_scale,
|
|
||||||
)
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp)
|
|
||||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -391,18 +288,6 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp)
|
hyps.append(hyp)
|
||||||
elif params.decoding_method == "modified_beam_search_LODR":
|
|
||||||
hyp_tokens = modified_beam_search_LODR(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam_size,
|
|
||||||
LODR_lm=ngram_lm,
|
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
|
||||||
LM=LM,
|
|
||||||
)
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp)
|
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -430,17 +315,6 @@ def decode_one_batch(
|
|||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": _group_channels(hyps)}, masks_dict
|
return {"greedy_search": _group_channels(hyps)}, masks_dict
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
|
||||||
key = f"beam_{params.beam}_"
|
|
||||||
key += f"max_contexts_{params.max_contexts}_"
|
|
||||||
key += f"max_states_{params.max_states}"
|
|
||||||
if "nbest" in params.decoding_method:
|
|
||||||
key += f"_num_paths_{params.num_paths}_"
|
|
||||||
key += f"nbest_scale_{params.nbest_scale}"
|
|
||||||
if "LG" in params.decoding_method:
|
|
||||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
|
||||||
|
|
||||||
return {key: _group_channels(hyps)}, masks_dict
|
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
|
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
|
||||||
|
|
||||||
@ -450,11 +324,6 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
|
||||||
ngram_lm: Optional[NgramLm] = None,
|
|
||||||
ngram_lm_scale: float = 1.0,
|
|
||||||
LM: Optional[LmScorer] = None,
|
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -467,12 +336,6 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
word_table:
|
|
||||||
The word symbol table.
|
|
||||||
decoding_graph:
|
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -502,12 +365,6 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
word_table=word_table,
|
|
||||||
batch=batch,
|
|
||||||
ngram_lm=ngram_lm,
|
|
||||||
ngram_lm_scale=ngram_lm_scale,
|
|
||||||
LM=LM,
|
|
||||||
)
|
)
|
||||||
masks.update(masks_dict)
|
masks.update(masks_dict)
|
||||||
|
|
||||||
@ -607,12 +464,7 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
|
||||||
"fast_beam_search_nbest",
|
|
||||||
"fast_beam_search_nbest_LG",
|
|
||||||
"fast_beam_search_nbest_oracle",
|
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
"modified_beam_search_LODR",
|
|
||||||
), f"Decoding method {params.decoding_method} is not supported."
|
), f"Decoding method {params.decoding_method} is not supported."
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -621,16 +473,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
|
||||||
if "nbest" in params.decoding_method:
|
|
||||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
|
||||||
params.suffix += f"-num-paths-{params.num_paths}"
|
|
||||||
if "LG" in params.decoding_method:
|
|
||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
|
||||||
elif "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
@ -639,11 +482,6 @@ def main():
|
|||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
if "LODR" in params.decoding_method:
|
|
||||||
params.suffix += (
|
|
||||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -750,52 +588,6 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
word_table = lexicon.word_table
|
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
|
||||||
logging.info(f"Loading {lg_filename}")
|
|
||||||
decoding_graph = k2.Fsa.from_dict(
|
|
||||||
torch.load(lg_filename, map_location=device)
|
|
||||||
)
|
|
||||||
decoding_graph.scores *= params.ngram_lm_scale
|
|
||||||
else:
|
|
||||||
word_table = None
|
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
|
||||||
else:
|
|
||||||
decoding_graph = None
|
|
||||||
word_table = None
|
|
||||||
|
|
||||||
# only load N-gram LM when needed
|
|
||||||
if "LODR" in params.decoding_method:
|
|
||||||
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
|
|
||||||
logging.info(f"lm filename: {lm_filename}")
|
|
||||||
ngram_lm = NgramLm(
|
|
||||||
lm_filename,
|
|
||||||
backoff_id=params.backoff_id,
|
|
||||||
is_binary=False,
|
|
||||||
)
|
|
||||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
|
||||||
ngram_lm_scale = params.ngram_lm_scale
|
|
||||||
else:
|
|
||||||
ngram_lm = None
|
|
||||||
ngram_lm_scale = None
|
|
||||||
|
|
||||||
# only load the neural network LM if doing shallow fusion
|
|
||||||
if params.use_shallow_fusion:
|
|
||||||
LM = LmScorer(
|
|
||||||
lm_type=params.lm_type,
|
|
||||||
params=params,
|
|
||||||
device=device,
|
|
||||||
lm_scale=params.lm_scale,
|
|
||||||
)
|
|
||||||
LM.to(device)
|
|
||||||
LM.eval()
|
|
||||||
|
|
||||||
else:
|
|
||||||
LM = None
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -817,11 +609,6 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
ngram_lm=ngram_lm,
|
|
||||||
ngram_lm_scale=ngram_lm_scale,
|
|
||||||
LM=LM,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
@ -844,11 +631,6 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
ngram_lm=ngram_lm,
|
|
||||||
ngram_lm_scale=ngram_lm_scale,
|
|
||||||
LM=LM,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user