mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
support RNNLM shallow fusion for LSTM transducer
This commit is contained in:
parent
d389524d45
commit
de2f5e3e6d
@ -115,7 +115,8 @@ from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_ngram_rescoring,
|
||||
modified_beam_search_rnnlm_shallow_fusion,
|
||||
|
||||
)
|
||||
from librispeech import LibriSpeech
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -128,6 +129,7 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
@ -216,7 +218,7 @@ def get_parser():
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_ngram_rescoring
|
||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
@ -307,21 +309,74 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Token Ngram used for rescoring.
|
||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||
"--rnn-lm-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""Used only when --method is modified_beam_search3.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""ID of the backoff symbol.
|
||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ilm-scale",
|
||||
type=float,
|
||||
default=-0.1
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -336,6 +391,8 @@ def decode_one_batch(
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -469,14 +526,14 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
|
||||
hyp_tokens = modified_beam_search_ngram_rescoring(
|
||||
elif params.decoding_method == "modified_beam_search_sf_rnnlm":
|
||||
hyp_tokens = modified_beam_search_sf_rnnlm_batched(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -531,7 +588,9 @@ def decode_dataset(
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
rnnlm: Optional[NgramLm] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -572,6 +631,9 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]])
|
||||
|
||||
logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}")
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -582,6 +644,8 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -607,7 +671,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
@ -667,7 +731,7 @@ def main():
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_ngram_rescoring",
|
||||
"modified_beam_search_sf_rnnlm",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -692,7 +756,12 @@ def main():
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
|
||||
if "rnnlm" in params.decoding_method:
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||
|
||||
if "ILME" in params.decoding_method:
|
||||
params.suffix += f"-ILME-scale={params.ilm_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
@ -806,14 +875,28 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
str(params.lang_dir / lm_filename),
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
# only load rnnlm if used
|
||||
if "rnnlm" in params.decoding_method:
|
||||
rnn_lm_scale = params.rnn_lm_scale
|
||||
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
assert params.rnn_lm_avg == 1
|
||||
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
else:
|
||||
rnn_lm_model = None
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
@ -861,6 +944,8 @@ def main():
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
rnnlm=rnn_lm_model,
|
||||
rnnlm_scale=rnn_lm_scale,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
@ -25,13 +25,8 @@ from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import (
|
||||
DecodingResults,
|
||||
add_eos,
|
||||
add_sos,
|
||||
get_texts,
|
||||
get_texts_with_timestamp,
|
||||
)
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
|
||||
@ -43,8 +38,7 @@ def fast_beam_search_one_best(
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -68,12 +62,8 @@ def fast_beam_search_one_best(
|
||||
Max contexts pre stream per frame.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -87,11 +77,8 @@ def fast_beam_search_one_best(
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
|
||||
def fast_beam_search_nbest_LG(
|
||||
@ -106,8 +93,7 @@ def fast_beam_search_nbest_LG(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -144,12 +130,8 @@ def fast_beam_search_nbest_LG(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -214,10 +196,9 @@ def fast_beam_search_nbest_LG(
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
def fast_beam_search_nbest(
|
||||
@ -232,8 +213,7 @@ def fast_beam_search_nbest(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -270,12 +250,8 @@ def fast_beam_search_nbest(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -304,10 +280,9 @@ def fast_beam_search_nbest(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
def fast_beam_search_nbest_oracle(
|
||||
@ -323,8 +298,7 @@ def fast_beam_search_nbest_oracle(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -365,12 +339,8 @@ def fast_beam_search_nbest_oracle(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -409,10 +379,8 @@ def fast_beam_search_nbest_oracle(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
@ -502,11 +470,8 @@ def fast_beam_search(
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
max_sym_per_frame: int,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
) -> List[int]:
|
||||
"""Greedy search for a single utterance.
|
||||
Args:
|
||||
model:
|
||||
@ -516,12 +481,8 @@ def greedy_search(
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -547,10 +508,6 @@ def greedy_search(
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which hyp[i] is decoded
|
||||
timestamp = []
|
||||
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
@ -577,7 +534,6 @@ def greedy_search(
|
||||
y = logits.argmax().item()
|
||||
if y not in (blank_id, unk_id):
|
||||
hyp.append(y)
|
||||
timestamp.append(t)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
@ -592,21 +548,14 @@ def greedy_search(
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return hyp
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=[hyp],
|
||||
timestamps=[timestamp],
|
||||
)
|
||||
|
||||
|
||||
def greedy_search_batch(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
@ -616,12 +565,9 @@ def greedy_search_batch(
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return a list-of-list of token IDs containing the decoded results.
|
||||
len(ans) equals to encoder_out.size(0).
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -646,10 +592,6 @@ def greedy_search_batch(
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
# timestamp[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
@ -663,7 +605,7 @@ def greedy_search_batch(
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -685,7 +627,6 @@ def greedy_search_batch(
|
||||
for i, v in enumerate(y):
|
||||
if v not in (blank_id, unk_id):
|
||||
hyps[i].append(v)
|
||||
timestamps[i].append(t)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
@ -700,19 +641,11 @@ def greedy_search_batch(
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -725,11 +658,9 @@ class Hypothesis:
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which ys[i] is decoded
|
||||
timestamp: List[int]
|
||||
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
state: Optional = None
|
||||
lm_score: Optional=None
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
@ -878,8 +809,7 @@ def modified_beam_search(
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
@ -894,12 +824,9 @@ def modified_beam_search(
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -927,7 +854,6 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
|
||||
@ -935,7 +861,7 @@ def modified_beam_search(
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -1013,44 +939,30 @@ def modified_beam_search(
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
It decodes only one utterance at a time. We keep it only for reference.
|
||||
@ -1065,13 +977,8 @@ def _deprecated_modified_beam_search(
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
@ -1091,7 +998,6 @@ def _deprecated_modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
@ -1150,24 +1056,17 @@ def _deprecated_modified_beam_search(
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def beam_search(
|
||||
@ -1175,8 +1074,7 @@ def beam_search(
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
) -> List[int]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
||||
@ -1191,13 +1089,8 @@ def beam_search(
|
||||
Beam size.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -1224,7 +1117,7 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -1285,13 +1178,7 @@ def beam_search(
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=y_star.ys[:],
|
||||
log_prob=new_y_star_log_prob,
|
||||
timestamp=y_star.timestamp[:],
|
||||
)
|
||||
)
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
@ -1300,14 +1187,7 @@ def beam_search(
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
new_timestamp = y_star.timestamp + [t]
|
||||
A.add(
|
||||
Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
)
|
||||
)
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
@ -1323,11 +1203,7 @@ def beam_search(
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def fast_beam_search_with_nbest_rescoring(
|
||||
@ -1347,8 +1223,7 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model. The shortest path within the
|
||||
@ -1390,13 +1265,10 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1474,18 +1346,16 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
log_semiring=False,
|
||||
)
|
||||
|
||||
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
|
||||
ans: Dict[str, List[List[int]]] = {}
|
||||
for s in ngram_lm_scale_list:
|
||||
key = f"ngram_lm_scale_{s}"
|
||||
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
ans[key] = hyps
|
||||
|
||||
return ans
|
||||
|
||||
@ -1509,8 +1379,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model and a rnn-lm.
|
||||
@ -1556,13 +1425,10 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1674,44 +1540,150 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
ans[key] = hyps
|
||||
|
||||
return ans
|
||||
|
||||
def modified_beam_search_sf_rnnlm(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
sp,
|
||||
rnnlm: RnnLmModel,
|
||||
rnnlm_scale: float,
|
||||
beam: int = 4,
|
||||
):
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
lm_scale = rnnlm_scale
|
||||
|
||||
def modified_beam_search_ngram_rescoring(
|
||||
assert rnnlm is not None
|
||||
assert encoder_out.ndim == 2, encoder_out.shape
|
||||
rnnlm.clean_cache()
|
||||
blank_id = model.decoder.blank_id
|
||||
sos_id = sp.piece_to_id("<sos/eos>")
|
||||
eos_id = sp.piece_to_id("<sos/eos>")
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
T = encoder_out.shape[0]
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[t : t + 1]
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyp in A]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
# decoder_out is of shape (num_hyps, joiner_dim)
|
||||
current_encoder_out = current_encoder_out.repeat(len(A), 1)
|
||||
# current_encoder_out is of shape (num_hyps, encoder_out_dim)
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(
|
||||
beam
|
||||
) # get topk tokens and scores
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[hyp_idx] # get hyp
|
||||
new_ys = hyp.ys[:]
|
||||
state = "ys=" + "+".join(list(map(str, new_ys)))
|
||||
tokens = k2.RaggedTensor([new_ys[context_size:]])
|
||||
|
||||
lm_score = rnnlm.predict(
|
||||
tokens, state, sos_id, eos_id, blank_id
|
||||
) # get rnnlm score
|
||||
|
||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||
new_token = topk_token_indexes[k] # get token
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
# state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||
hyp_log_prob += (
|
||||
lm_score[new_token] * lm_scale
|
||||
) # add the lm score
|
||||
else:
|
||||
new_ys = new_ys
|
||||
new_log_prob = hyp_log_prob
|
||||
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[context_size:]
|
||||
|
||||
def modified_beam_search_rnnlm_shallow_fusion(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ngram_lm: NgramLm,
|
||||
ngram_lm_scale: float,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
rnnlm: RnnLmModel,
|
||||
rnnlm_scale: float,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
"""Modified_beam_search + RNNLM shallow fusion
|
||||
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
model (Transducer):
|
||||
The transducer model
|
||||
encoder_out (torch.Tensor):
|
||||
Encoder output in (N,T,C)
|
||||
encoder_out_lens (torch.Tensor):
|
||||
A 1-D tensor of shape (N,), containing the number of
|
||||
valid frames in encoder_out before padding.
|
||||
sp:
|
||||
Sentence piece generator.
|
||||
rnnlm (RnnLmModel):
|
||||
RNNLM
|
||||
rnnlm_scale (float):
|
||||
scale of RNNLM in shallow fusion
|
||||
beam (int, optional):
|
||||
Beam size. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
assert rnnlm is not None
|
||||
lm_scale = rnnlm_scale
|
||||
vocab_size = rnnlm.vocab_size
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
@ -1721,26 +1693,33 @@ def modified_beam_search_ngram_rescoring(
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
sos_id = sp.piece_to_id("<sos/eos>")
|
||||
eos_id = sp.piece_to_id("<sos/eos>")
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
lm_scale = ngram_lm_scale
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
# get initial lm score and lm state by scoring the "sos" token
|
||||
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
|
||||
init_score, init_states = rnnlm.score_token(sos_token)
|
||||
|
||||
B = [HypothesisList() for _ in range(N)]
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
state_cost=NgramLmStateCost(ngram_lm),
|
||||
state=init_states,
|
||||
lm_score=init_score.reshape(-1)
|
||||
)
|
||||
)
|
||||
|
||||
rnnlm.clean_cache()
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
@ -1748,7 +1727,7 @@ def modified_beam_search_ngram_rescoring(
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
current_encoder_out = encoder_out.data[start:end] # get batch
|
||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
offset = end
|
||||
@ -1762,12 +1741,8 @@ def modified_beam_search_ngram_rescoring(
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[
|
||||
hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
|
||||
for hyps in A
|
||||
for hyp in hyps
|
||||
]
|
||||
) # (num_hyps, 1)
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
@ -1777,10 +1752,7 @@ def modified_beam_search_ngram_rescoring(
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
@ -1795,12 +1767,14 @@ def modified_beam_search_ngram_rescoring(
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs = (logits / temperature).log_softmax(
|
||||
log_probs = logits.log_softmax(
|
||||
dim=-1
|
||||
) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
@ -1811,6 +1785,38 @@ def modified_beam_search_ngram_rescoring(
|
||||
shape=log_probs_shape, value=log_probs
|
||||
)
|
||||
|
||||
|
||||
# for all hyps with a non-blank new token, score it
|
||||
token_list = []
|
||||
hs = []
|
||||
cs = []
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
|
||||
assert new_token != 0, new_token
|
||||
token_list.append([new_token])
|
||||
hs.append(hyp.state[0])
|
||||
cs.append(hyp.state[1])
|
||||
# forward RNNLM to get new states and scores
|
||||
if len(token_list) != 0:
|
||||
tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1)
|
||||
|
||||
hs = torch.cat(hs, dim=1).to(device)
|
||||
cs = torch.cat(cs, dim=1).to(device)
|
||||
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs))
|
||||
|
||||
count = 0 # index, used to locate score and lm states
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
@ -1823,21 +1829,29 @@ def modified_beam_search_ngram_rescoring(
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
ys = hyp.ys[:]
|
||||
|
||||
lm_score = hyp.lm_score
|
||||
state = hyp.state
|
||||
|
||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||
else:
|
||||
state_cost = hyp.state_cost
|
||||
|
||||
# We only keep AM scores in new_hyp.log_prob
|
||||
new_log_prob = (
|
||||
topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
|
||||
)
|
||||
ys.append(new_token)
|
||||
hyp_log_prob += (
|
||||
lm_score[new_token] * lm_scale
|
||||
) # add the lm score
|
||||
|
||||
lm_score = scores[count]
|
||||
state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1))
|
||||
count += 1
|
||||
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
|
||||
ys=ys,
|
||||
log_prob=hyp_log_prob,
|
||||
state=state,
|
||||
lm_score=lm_score
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
|
@ -18,8 +18,9 @@ import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import k2
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.utils import add_eos, add_sos, make_pad_mask
|
||||
|
||||
|
||||
class RnnLmModel(torch.nn.Module):
|
||||
@ -72,6 +73,8 @@ class RnnLmModel(torch.nn.Module):
|
||||
else:
|
||||
logging.info("Not tying weights")
|
||||
|
||||
self.cache = {}
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@ -118,3 +121,124 @@ class RnnLmModel(torch.nn.Module):
|
||||
nll_loss = nll_loss.reshape(batch_size, -1)
|
||||
|
||||
return nll_loss
|
||||
|
||||
def get_init_states(self, sos):
|
||||
p = next(self.parameters())
|
||||
|
||||
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
|
||||
device = next(self.parameters()).device
|
||||
batch_size = len(token_lens)
|
||||
|
||||
sos_tokens = add_sos(tokens, sos_id)
|
||||
tokens_eos = add_eos(tokens, eos_id)
|
||||
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
|
||||
|
||||
sentence_lengths = (
|
||||
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
|
||||
)
|
||||
|
||||
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
|
||||
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
x_tokens = x_tokens.to(torch.int64).to(device)
|
||||
y_tokens = y_tokens.to(torch.int64).to(device)
|
||||
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
|
||||
|
||||
embedding = self.input_embedding(x_tokens)
|
||||
|
||||
# Note: We use batch_first==True
|
||||
rnn_out, states = self.rnn(embedding)
|
||||
logits = self.output_linear(rnn_out)
|
||||
mask = torch.zeros(logits.shape).bool().to(device)
|
||||
for i in range(batch_size):
|
||||
mask[i, token_lens[i], :] = True
|
||||
logits = logits[mask].reshape(batch_size, -1)
|
||||
|
||||
return logits[:,:].log_softmax(-1), states
|
||||
|
||||
def clean_cache(self):
|
||||
self.cache = {}
|
||||
|
||||
def score_token(self, tokens: torch.Tensor, state=None):
|
||||
device = next(self.parameters()).device
|
||||
batch_size = tokens.size(0)
|
||||
if state:
|
||||
h,c = state
|
||||
else:
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
|
||||
|
||||
embedding = self.input_embedding(tokens)
|
||||
rnn_out, states = self.rnn(embedding, (h,c))
|
||||
logits = self.output_linear(rnn_out)
|
||||
|
||||
return logits[:,0].log_softmax(-1), states
|
||||
|
||||
def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None):
|
||||
batch_size = len(token_lens)
|
||||
if state:
|
||||
h,c = state
|
||||
else:
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
sos_tokens = add_sos(tokens, sos_id)
|
||||
tokens_eos = add_eos(tokens, eos_id)
|
||||
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
|
||||
|
||||
sentence_lengths = (
|
||||
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
|
||||
)
|
||||
|
||||
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
|
||||
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
x_tokens = x_tokens.to(torch.int64).to(device)
|
||||
y_tokens = y_tokens.to(torch.int64).to(device)
|
||||
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
|
||||
|
||||
embedding = self.input_embedding(x_tokens)
|
||||
|
||||
# Note: We use batch_first==True
|
||||
rnn_out, states = self.rnn(embedding, (h,c))
|
||||
logits = self.output_linear(rnn_out)
|
||||
|
||||
return logits, states
|
||||
|
||||
if __name__=="__main__":
|
||||
LM = RnnLmModel(500, 2048, 2048, 3, True)
|
||||
h0 = torch.zeros(3, 1, 2048)
|
||||
c0 = torch.zeros(3, 1, 2048)
|
||||
seq = [[0,1,2,3]]
|
||||
seq_lens = [len(s) for s in seq]
|
||||
tokens = k2.RaggedTensor(seq)
|
||||
output1, state = LM.forward_with_state(
|
||||
tokens,
|
||||
seq_lens,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
state=(h0,c0)
|
||||
)
|
||||
seq = [[0,1,2,3,4]]
|
||||
seq_lens = [len(s) for s in seq]
|
||||
tokens = k2.RaggedTensor(seq)
|
||||
output2, _ = LM.forward_with_state(
|
||||
tokens,
|
||||
seq_lens,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
state=(h0,c0)
|
||||
)
|
||||
|
||||
seq = [[4]]
|
||||
seq_lens = [len(s) for s in seq]
|
||||
output3 = LM.score_token(seq, seq_lens, state)
|
||||
|
||||
print("Finished")
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user