support RNNLM shallow fusion for LSTM transducer

This commit is contained in:
marcoyang 2022-11-02 16:15:56 +08:00
parent d389524d45
commit de2f5e3e6d
3 changed files with 503 additions and 280 deletions


@ -115,7 +115,8 @@ from beam_search import (
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
@ -128,6 +129,7 @@ from icefall.checkpoint import (
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
    AttributeDict,
    setup_logger,
@ -216,7 +218,7 @@ def get_parser():
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
          - modified_beam_search_rnnlm_shallow_fusion  # RNN LM shallow fusion
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
@ -307,21 +309,74 @@ def get_parser():
) )
    parser.add_argument(
        "--rnn-lm-scale",
        type=float,
        default=0.0,
        help="""Used only when --decoding-method is
        modified_beam_search_rnnlm_shallow_fusion.
        It specifies the scale of the RNN LM scores.
        """,
    )

    parser.add_argument(
        "--rnn-lm-exp-dir",
        type=str,
        default="rnn_lm/exp",
        help="""Used only when --decoding-method is
        modified_beam_search_rnnlm_shallow_fusion.
        It specifies the path to the RNN LM exp dir.
        """,
    )

    parser.add_argument(
        "--rnn-lm-epoch",
        type=int,
        default=7,
        help="""Used only when --decoding-method is
        modified_beam_search_rnnlm_shallow_fusion.
        It specifies the checkpoint to use.
        """,
    )

    parser.add_argument(
        "--rnn-lm-avg",
        type=int,
        default=2,
        help="""Used only when --decoding-method is
        modified_beam_search_rnnlm_shallow_fusion.
        It specifies the number of checkpoints to average.
        """,
    )

    parser.add_argument(
        "--rnn-lm-embedding-dim",
        type=int,
        default=2048,
        help="Embedding dim of the RNN LM model.",
    )

    parser.add_argument(
        "--rnn-lm-hidden-dim",
        type=int,
        default=2048,
        help="Hidden dim of the RNN LM model.",
    )

    parser.add_argument(
        "--rnn-lm-num-layers",
        type=int,
        default=4,
        help="Number of RNN layers in the model.",
    )

    parser.add_argument(
        "--rnn-lm-tie-weights",
        type=str2bool,
        default=False,
        help="""True to share the weights between the input embedding layer and the
        last output linear layer.
        """,
    )

    parser.add_argument(
        "--ilm-scale",
        type=float,
        default=-0.1,
        help="Scale of the internal language model (ILM) score.",
    )

    add_model_arguments(parser)

    return parser
@ -336,6 +391,8 @@ def decode_one_batch(
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    rnnlm: Optional[RnnLmModel] = None,
    rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@ -469,14 +526,14 @@ def decode_one_batch(
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            sp=sp,
            rnnlm=rnnlm,
            rnnlm_scale=rnnlm_scale,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@ -531,7 +588,9 @@ def decode_dataset(
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
    rnnlm: Optional[RnnLmModel] = None,
    rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
@ -572,6 +631,9 @@ def decode_dataset(
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        total_duration = sum(cut.duration for cut in batch["supervisions"]["cut"])
        logging.info(
            f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, "
            f"total duration is {total_duration:.2f} s"
        )

        hyps_dict = decode_one_batch(
            params=params,
@ -582,6 +644,8 @@ def decode_dataset(
            batch=batch,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
            rnnlm=rnnlm,
            rnnlm_scale=rnnlm_scale,
        )

        for name, hyps in hyps_dict.items():
@ -607,7 +671,7 @@ def decode_dataset(
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
@ -667,7 +731,7 @@ def main():
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
"modified_beam_search_ngram_rescoring", "modified_beam_search_sf_rnnlm",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -692,7 +756,12 @@ def main():
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if "rnnlm" in params.decoding_method:
        params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"

    if "ILME" in params.decoding_method:
        params.suffix += f"-ILME-scale={params.ilm_scale}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
@ -806,14 +875,28 @@ def main():
    model.to(device)
    model.eval()

    # only load the RNN LM if it is used for decoding
    if "rnnlm" in params.decoding_method:
        rnn_lm_scale = params.rnn_lm_scale

        rnn_lm_model = RnnLmModel(
            vocab_size=params.vocab_size,
            embedding_dim=params.rnn_lm_embedding_dim,
            hidden_dim=params.rnn_lm_hidden_dim,
            num_layers=params.rnn_lm_num_layers,
            tie_weights=params.rnn_lm_tie_weights,
        )
        assert params.rnn_lm_avg == 1
        load_checkpoint(
            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
            rnn_lm_model,
        )
        rnn_lm_model.to(device)
        rnn_lm_model.eval()
    else:
        rnn_lm_model = None
        rnn_lm_scale = 0.0

    # The n-gram LM is no longer loaded in this script; pass None downstream.
    ngram_lm = None

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
@ -861,6 +944,8 @@ def main():
            decoding_graph=decoding_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=params.ngram_lm_scale,
            rnnlm=rnn_lm_model,
            rnnlm_scale=rnn_lm_scale,
        )

        save_results(
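For reference, a decoding run that exercises the new shallow fusion path could look like the command below. The recipe directory and the values of the standard decode.py flags (--epoch, --avg, --max-duration, ...) are illustrative assumptions, not values taken from this commit:

./lstm_transducer_stateless2/decode.py \
    --epoch 35 \
    --avg 10 \
    --exp-dir ./lstm_transducer_stateless2/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
    --beam-size 4 \
    --rnn-lm-scale 0.3 \
    --rnn-lm-exp-dir ./rnn_lm/exp \
    --rnn-lm-epoch 7 \
    --rnn-lm-avg 1 \
    --rnn-lm-num-layers 4 \
    --rnn-lm-tie-weights 1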


@ -16,7 +16,7 @@
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional

import k2
import sentencepiece as spm
@ -25,13 +25,8 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import add_eos, add_sos, get_texts
def fast_beam_search_one_best( def fast_beam_search_one_best(
@ -43,8 +38,7 @@ def fast_beam_search_one_best(
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then A lattice is first obtained using fast beam search, and then
@ -68,12 +62,8 @@ def fast_beam_search_one_best(
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -87,11 +77,8 @@ def fast_beam_search_one_best(
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
if not return_timestamps: return hyps
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
def fast_beam_search_nbest_LG( def fast_beam_search_nbest_LG(
@ -106,8 +93,7 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
The process to get the results is: The process to get the results is:
@ -144,12 +130,8 @@ def fast_beam_search_nbest_LG(
single precision. single precision.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -214,10 +196,9 @@ def fast_beam_search_nbest_LG(
best_hyp_indexes = ragged_tot_scores.argmax() best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
if not return_timestamps: hyps = get_texts(best_path)
return get_texts(best_path)
else: return hyps
return get_texts_with_timestamp(best_path)
def fast_beam_search_nbest( def fast_beam_search_nbest(
@ -232,8 +213,7 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
The process to get the results is: The process to get the results is:
@ -270,12 +250,8 @@ def fast_beam_search_nbest(
single precision. single precision.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -304,10 +280,9 @@ def fast_beam_search_nbest(
best_path = k2.index_fsa(nbest.fsa, max_indexes) best_path = k2.index_fsa(nbest.fsa, max_indexes)
if not return_timestamps: hyps = get_texts(best_path)
return get_texts(best_path)
else: return hyps
return get_texts_with_timestamp(best_path)
def fast_beam_search_nbest_oracle( def fast_beam_search_nbest_oracle(
@ -323,8 +298,7 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then A lattice is first obtained using fast beam search, and then
@ -365,12 +339,8 @@ def fast_beam_search_nbest_oracle(
yields more unique paths. yields more unique paths.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -409,10 +379,8 @@ def fast_beam_search_nbest_oracle(
best_path = k2.index_fsa(nbest.fsa, max_indexes) best_path = k2.index_fsa(nbest.fsa, max_indexes)
if not return_timestamps: hyps = get_texts(best_path)
return get_texts(best_path) return hyps
else:
return get_texts_with_timestamp(best_path)
def fast_beam_search( def fast_beam_search(
@ -502,11 +470,8 @@ def fast_beam_search(
def greedy_search( def greedy_search(
model: Transducer, model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
encoder_out: torch.Tensor, ) -> List[int]:
max_sym_per_frame: int,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance. """Greedy search for a single utterance.
Args: Args:
model: model:
@ -516,12 +481,8 @@ def greedy_search(
max_sym_per_frame: max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%. would be 100%.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
@ -547,10 +508,6 @@ def greedy_search(
t = 0 t = 0
hyp = [blank_id] * context_size hyp = [blank_id] * context_size
# timestamp[i] is the frame index after subsampling
# on which hyp[i] is decoded
timestamp = []
# Maximum symbols per utterance. # Maximum symbols per utterance.
max_sym_per_utt = 1000 max_sym_per_utt = 1000
@ -577,7 +534,6 @@ def greedy_search(
y = logits.argmax().item() y = logits.argmax().item()
if y not in (blank_id, unk_id): if y not in (blank_id, unk_id):
hyp.append(y) hyp.append(y)
timestamp.append(t)
decoder_input = torch.tensor( decoder_input = torch.tensor(
[hyp[-context_size:]], device=device [hyp[-context_size:]], device=device
).reshape(1, context_size) ).reshape(1, context_size)
@ -592,21 +548,14 @@ def greedy_search(
t += 1 t += 1
hyp = hyp[context_size:] # remove blanks hyp = hyp[context_size:] # remove blanks
if not return_timestamps: return hyp
return hyp
else:
return DecodingResults(
tokens=[hyp],
timestamps=[timestamp],
)
def greedy_search_batch( def greedy_search_batch(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args: Args:
model: model:
@ -616,12 +565,9 @@ def greedy_search_batch(
encoder_out_lens: encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding. encoder_out before padding.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return a list-of-list of token IDs containing the decoded results.
Else, return a DecodingResults object containing len(ans) equals to encoder_out.size(0).
decoded result and corresponding timestamps.
""" """
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0) assert encoder_out.size(0) >= 1, encoder_out.size(0)
@ -646,10 +592,6 @@ def greedy_search_batch(
hyps = [[blank_id] * context_size for _ in range(N)] hyps = [[blank_id] * context_size for _ in range(N)]
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
decoder_input = torch.tensor( decoder_input = torch.tensor(
hyps, hyps,
device=device, device=device,
@ -663,7 +605,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0 offset = 0
for (t, batch_size) in enumerate(batch_size_list): for batch_size in batch_size_list:
start = offset start = offset
end = offset + batch_size end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] current_encoder_out = encoder_out.data[start:end]
@ -685,7 +627,6 @@ def greedy_search_batch(
for i, v in enumerate(y): for i, v in enumerate(y):
if v not in (blank_id, unk_id): if v not in (blank_id, unk_id):
hyps[i].append(v) hyps[i].append(v)
timestamps[i].append(t)
emitted = True emitted = True
if emitted: if emitted:
# update decoder output # update decoder output
@ -700,19 +641,11 @@ def greedy_search_batch(
sorted_ans = [h[context_size:] for h in hyps] sorted_ans = [h[context_size:] for h in hyps]
ans = [] ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist() unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N): for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]]) ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
if not return_timestamps: return ans
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
@dataclass @dataclass
@ -725,11 +658,9 @@ class Hypothesis:
    # It contains only one entry.
    log_prob: torch.Tensor

    state_cost: Optional[NgramLmStateCost] = None

    # The RNN LM state of this hypothesis, e.g. the (h, c) tensors of an LSTM LM.
    state: Optional = None

    # The LM log-probs over the next token given `ys`, kept so that shallow
    # fusion does not need to re-run the LM for an unchanged prefix.
    lm_score: Optional = None
@property @property
def key(self) -> str: def key(self) -> str:
@ -878,8 +809,7 @@ def modified_beam_search(
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[List[int]]:
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args: Args:
@ -894,12 +824,9 @@ def modified_beam_search(
Number of active paths during the beam search. Number of active paths during the beam search.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return a list-of-list of token IDs. ans[i] is the decoding results
Else, return a DecodingResults object containing for the i-th utterance.
decoded result and corresponding timestamps.
""" """
assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0) assert encoder_out.size(0) >= 1, encoder_out.size(0)
@ -927,7 +854,6 @@ def modified_beam_search(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
) )
) )
@ -935,7 +861,7 @@ def modified_beam_search(
offset = 0 offset = 0
finalized_B = [] finalized_B = []
for (t, batch_size) in enumerate(batch_size_list): for batch_size in batch_size_list:
start = offset start = offset
end = offset + batch_size end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] current_encoder_out = encoder_out.data[start:end]
@ -1013,44 +939,30 @@ def modified_beam_search(
new_ys = hyp.ys[:] new_ys = hyp.ys[:]
new_token = topk_token_indexes[k] new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id): if new_token not in (blank_id, unk_id):
new_ys.append(new_token) new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k] new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis( new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp) B[i].add(new_hyp)
B = B + finalized_B B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B] best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps] sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = [] ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist() unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N): for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]]) ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps: return ans
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
def _deprecated_modified_beam_search( def _deprecated_modified_beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
return_timestamps: bool = False, ) -> List[int]:
) -> Union[List[int], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
It decodes only one utterance at a time. We keep it only for reference. It decodes only one utterance at a time. We keep it only for reference.
@ -1065,13 +977,8 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam: beam:
Beam size. Beam size.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
@ -1091,7 +998,6 @@ def _deprecated_modified_beam_search(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
) )
) )
encoder_out = model.joiner.encoder_proj(encoder_out) encoder_out = model.joiner.encoder_proj(encoder_out)
@ -1150,24 +1056,17 @@ def _deprecated_modified_beam_search(
for i in range(len(topk_hyp_indexes)): for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]] hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:] new_ys = hyp.ys[:]
new_timestamp = hyp.timestamp[:]
new_token = topk_token_indexes[i] new_token = topk_token_indexes[i]
if new_token not in (blank_id, unk_id): if new_token not in (blank_id, unk_id):
new_ys.append(new_token) new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[i] new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis( new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B.add(new_hyp) B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True) best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
if not return_timestamps: return ys
return ys
else:
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
def beam_search( def beam_search(
@ -1175,8 +1074,7 @@ def beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> List[int]:
) -> Union[List[int], DecodingResults]:
""" """
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -1191,13 +1089,8 @@ def beam_search(
Beam size. Beam size.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
If return_timestamps is False, return the decoded result. Return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
""" """
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
@ -1224,7 +1117,7 @@ def beam_search(
t = 0 t = 0
B = HypothesisList() B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
max_sym_per_utt = 20000 max_sym_per_utt = 20000
@ -1285,13 +1178,7 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys # ys[:] returns a copy of ys
B.add( B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
timestamp=y_star.timestamp[:],
)
)
# Second, process other non-blank labels # Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1) values, indices = log_prob.topk(beam + 1)
@ -1300,14 +1187,7 @@ def beam_search(
continue continue
new_ys = y_star.ys + [i] new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v new_log_prob = y_star.log_prob + v
new_timestamp = y_star.timestamp + [t] A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
A.add(
Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
)
)
# Check whether B contains more than "beam" elements more probable # Check whether B contains more than "beam" elements more probable
# than the most probable in A # than the most probable in A
@ -1323,11 +1203,7 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True) best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
if not return_timestamps:
return ys
else:
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
def fast_beam_search_with_nbest_rescoring( def fast_beam_search_with_nbest_rescoring(
@ -1347,8 +1223,7 @@ def fast_beam_search_with_nbest_rescoring(
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> Dict[str, List[List[int]]]:
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_path are selected A lattice is first obtained using fast beam search, num_path are selected
and rescored using a given language model. The shortest path within the and rescored using a given language model. The shortest path within the
@ -1390,13 +1265,10 @@ def fast_beam_search_with_nbest_rescoring(
yields more unique paths. yields more unique paths.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
Return the decoded result in a dict, where the key has the form Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
optionally with timestamps. `xx` is the ngram LM scale value ngram LM scale value used during decoding, i.e., 0.1.
used during decoding, i.e., 0.1.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -1474,18 +1346,16 @@ def fast_beam_search_with_nbest_rescoring(
log_semiring=False, log_semiring=False,
) )
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} ans: Dict[str, List[List[int]]] = {}
for s in ngram_lm_scale_list: for s in ngram_lm_scale_list:
key = f"ngram_lm_scale_{s}" key = f"ngram_lm_scale_{s}"
tot_scores = am_scores.values + s * ngram_lm_scores tot_scores = am_scores.values + s * ngram_lm_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax() max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes) best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
if not return_timestamps: ans[key] = hyps
ans[key] = get_texts(best_path)
else:
ans[key] = get_texts_with_timestamp(best_path)
return ans return ans
@ -1509,8 +1379,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, ) -> Dict[str, List[List[int]]]:
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_path are selected A lattice is first obtained using fast beam search, num_path are selected
and rescored using a given language model and a rnn-lm. and rescored using a given language model and a rnn-lm.
@ -1556,13 +1425,10 @@ def fast_beam_search_with_nbest_rnn_rescoring(
yields more unique paths. yields more unique paths.
temperature: temperature:
Softmax temperature. Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns: Returns:
Return the decoded result in a dict, where the key has the form Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
optionally with timestamps. `xx` is the ngram LM scale value ngram LM scale value used during decoding, i.e., 0.1.
used during decoding, i.e., 0.1.
""" """
lattice = fast_beam_search( lattice = fast_beam_search(
model=model, model=model,
@ -1674,44 +1540,150 @@ def fast_beam_search_with_nbest_rnn_rescoring(
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax() max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes) best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
if not return_timestamps: ans[key] = hyps
ans[key] = get_texts(best_path)
else:
ans[key] = get_texts_with_timestamp(best_path)
return ans return ans
def modified_beam_search_sf_rnnlm(
model: Transducer,
encoder_out: torch.Tensor,
sp,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
):
encoder_out = model.joiner.encoder_proj(encoder_out)
lm_scale = rnnlm_scale
assert rnnlm is not None
assert encoder_out.ndim == 2, encoder_out.shape
rnnlm.clean_cache()
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
eos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
T = encoder_out.shape[0]
for t in range(T):
current_encoder_out = encoder_out[t : t + 1]
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyp in A]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, joiner_dim)
current_encoder_out = current_encoder_out.repeat(len(A), 1)
# current_encoder_out is of shape (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(
beam
) # get topk tokens and scores
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[hyp_idx] # get hyp
new_ys = hyp.ys[:]
state = "ys=" + "+".join(list(map(str, new_ys)))
tokens = k2.RaggedTensor([new_ys[context_size:]])
lm_score = rnnlm.predict(
tokens, state, sos_id, eos_id, blank_id
) # get rnnlm score
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k] # get token
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
# state_cost = hyp.state_cost.forward_one_step(new_token)
hyp_log_prob += (
lm_score[new_token] * lm_scale
) # add the lm score
else:
new_ys = new_ys
new_log_prob = hyp_log_prob
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
return best_hyp.ys[context_size:]
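The unbatched helper above queries the LM once per hypothesis and builds a string key from the token prefix ("ys=0+1+2"); rnnlm.predict itself is not part of this diff, so the snippet below is only a generic sketch of that prefix-keyed memoization idea, with cached_next_token_log_probs and lm_step as hypothetical names:

from typing import Callable, Dict, List

import torch

# Cache of next-token LM log-probs, keyed by the hypothesis prefix.
_lm_cache: Dict[str, torch.Tensor] = {}


def cached_next_token_log_probs(
    prefix: List[int],
    lm_step: Callable[[List[int]], torch.Tensor],
) -> torch.Tensor:
    """Return (and memoize) the LM's next-token log-probs for `prefix`.

    `lm_step(prefix)` is any callable that runs the LM over the prefix and
    returns a (vocab_size,) tensor of log-probabilities for the next token.
    """
    key = "ys=" + "+".join(map(str, prefix))
    if key not in _lm_cache:
        _lm_cache[key] = lm_step(prefix)
    return _lm_cache[key]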
def modified_beam_search_rnnlm_shallow_fusion(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    rnnlm: RnnLmModel,
    rnnlm_scale: float,
    beam: int = 4,
) -> List[List[int]]:
    """Modified beam search with RNN LM shallow fusion.

    Args:
      model (Transducer):
        The transducer model.
      encoder_out (torch.Tensor):
        Encoder output of shape (N, T, C).
      encoder_out_lens (torch.Tensor):
        A 1-D tensor of shape (N,), containing the number of
        valid frames in encoder_out before padding.
      sp:
        The BPE SentencePieceProcessor (used to look up the <sos/eos> id).
      rnnlm (RnnLmModel):
        The RNN LM used for shallow fusion.
      rnnlm_scale (float):
        Scale of the RNN LM scores in shallow fusion.
      beam (int, optional):
        Beam size. Defaults to 4.

    Returns:
      Return a list-of-list of token IDs. ans[i] is the decoding result
      for the i-th utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert rnnlm is not None
    lm_scale = rnnlm_scale
    vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out, input=encoder_out,
@ -1721,26 +1693,33 @@ def modified_beam_search_ngram_rescoring(
) )
    blank_id = model.decoder.blank_id
    sos_id = sp.piece_to_id("<sos/eos>")
    eos_id = sp.piece_to_id("<sos/eos>")
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # get the initial lm score and lm state by scoring the "sos" token
    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
    init_score, init_states = rnnlm.score_token(sos_token)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state=init_states,
                lm_score=init_score.reshape(-1),
            )
        )

    rnnlm.clean_cache()
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
@ -1748,7 +1727,7 @@ def modified_beam_search_ngram_rescoring(
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]  # get batch
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end
@ -1762,12 +1741,8 @@ def modified_beam_search_ngram_rescoring(
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)
decoder_input = torch.tensor( decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps], [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
@ -1777,10 +1752,7 @@ def modified_beam_search_ngram_rescoring(
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out) decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select( current_encoder_out = torch.index_select(
current_encoder_out, current_encoder_out,
dim=0, dim=0,
@ -1795,12 +1767,14 @@ def modified_beam_search_ngram_rescoring(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)
        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
@ -1811,6 +1785,38 @@ def modified_beam_search_ngram_rescoring(
shape=log_probs_shape, value=log_probs shape=log_probs_shape, value=log_probs
) )
        # for all hyps whose new token is non-blank, batch the RNN LM scoring
        token_list = []
        hs = []
        cs = []
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    assert new_token != 0, new_token
                    token_list.append([new_token])
                    # collect the LM states of this hypothesis
                    hs.append(hyp.state[0])
                    cs.append(hyp.state[1])

        # forward the RNN LM once to get new states and scores for all of them
        if len(token_list) != 0:
            tokens_to_score = (
                torch.tensor(token_list)
                .to(torch.int64)
                .to(device)
                .reshape(-1, 1)
            )
            hs = torch.cat(hs, dim=1).to(device)
            cs = torch.cat(cs, dim=1).to(device)
            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@ -1823,21 +1829,29 @@ def modified_beam_search_ngram_rescoring(
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                ys = hyp.ys[:]

                lm_score = hyp.lm_score
                state = hyp.state

                hyp_log_prob = topk_log_probs[k]  # score of the current hyp
                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    ys.append(new_token)
                    # shallow fusion: add the scaled LM score of the new token
                    hyp_log_prob += lm_score[new_token] * lm_scale

                    # pick up the new LM score and LM state of this hypothesis
                    lm_score = scores[count]
                    state = (
                        lm_states[0][:, count, :].unsqueeze(1),
                        lm_states[1][:, count, :].unsqueeze(1),
                    )
                    count += 1

                new_hyp = Hypothesis(
                    ys=ys,
                    log_prob=hyp_log_prob,
                    state=state,
                    lm_score=lm_score,
                )
                B[i].add(new_hyp)
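Conceptually, each iteration of the batched loop above does two things for every emitted non-blank token: it adds the scaled LM log-probability of that token to the hypothesis score, and it advances the per-hypothesis LM (h, c) state. A self-contained sketch of those two steps with a plain torch.nn.LSTM (the dimensions, the <sos> id, and the lm_scale value are illustrative, not taken from this commit):

import torch

vocab_size, embed_dim, hidden_dim, num_layers = 500, 32, 64, 2
embed = torch.nn.Embedding(vocab_size, embed_dim)
lstm = torch.nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
proj = torch.nn.Linear(hidden_dim, vocab_size)
lm_scale = 0.3  # illustrative value


def lm_step(tokens, state):
    """tokens: (num_hyps, 1) int64; state: (h, c), each (num_layers, num_hyps, hidden_dim)."""
    out, new_state = lstm(embed(tokens), state)
    return proj(out[:, 0]).log_softmax(dim=-1), new_state


num_hyps = 4
h = torch.zeros(num_layers, num_hyps, hidden_dim)
c = torch.zeros(num_layers, num_hyps, hidden_dim)

# LM prediction for the first real token, obtained by feeding <sos> (assumed id 1)
sos = torch.full((num_hyps, 1), 1, dtype=torch.int64)
lm_log_probs, (h, c) = lm_step(sos, (h, c))

# pretend these are the transducer's log-probs for the current frame
am_log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)
new_tokens = am_log_probs.argmax(dim=-1, keepdim=True)  # tokens each hypothesis emits

# shallow fusion: beam score gets the AM log-prob plus the scaled LM log-prob
fused = am_log_probs.gather(1, new_tokens) + lm_scale * lm_log_probs.gather(1, new_tokens)

# advance the LM state by consuming the emitted tokens, ready for the next emission
lm_log_probs, (h, c) = lm_step(new_tokens, (h, c))
print(fused.squeeze(1))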


@ -18,8 +18,9 @@ import logging
import torch
import torch.nn.functional as F
import k2

from icefall.utils import add_eos, add_sos, make_pad_mask
class RnnLmModel(torch.nn.Module):
@ -72,6 +73,8 @@ class RnnLmModel(torch.nn.Module):
        else:
            logging.info("Not tying weights")

        self.cache = {}

    def forward(
        self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
    ) -> torch.Tensor:
@ -118,3 +121,124 @@ class RnnLmModel(torch.nn.Module):
        nll_loss = nll_loss.reshape(batch_size, -1)

        return nll_loss
    def get_init_states(self, sos):
        p = next(self.parameters())
    def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
        device = next(self.parameters()).device
        batch_size = len(token_lens)

        sos_tokens = add_sos(tokens, sos_id)
        tokens_eos = add_eos(tokens, eos_id)
        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)

        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]

        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)

        x_tokens = x_tokens.to(torch.int64).to(device)
        y_tokens = y_tokens.to(torch.int64).to(device)
        sentence_lengths = sentence_lengths.to(torch.int64).to(device)

        embedding = self.input_embedding(x_tokens)

        # Note: We use batch_first=True
        rnn_out, states = self.rnn(embedding)
        logits = self.output_linear(rnn_out)

        # For each sequence, keep only the logits at the position right after
        # the last real token, i.e. the prediction given the full prefix.
        mask = torch.zeros(logits.shape).bool().to(device)
        for i in range(batch_size):
            mask[i, token_lens[i], :] = True
        logits = logits[mask].reshape(batch_size, -1)

        return logits.log_softmax(-1), states
    def clean_cache(self):
        self.cache = {}
    def score_token(self, tokens: torch.Tensor, state=None):
        device = next(self.parameters()).device
        batch_size = tokens.size(0)
        if state:
            h, c = state
        else:
            # Note: the LSTM state must have hidden_size features, not input_size.
            h = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size
            ).to(device)
            c = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size
            ).to(device)

        embedding = self.input_embedding(tokens)
        rnn_out, states = self.rnn(embedding, (h, c))
        logits = self.output_linear(rnn_out)

        return logits[:, 0].log_softmax(-1), states
    def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None):
        device = next(self.parameters()).device
        batch_size = len(token_lens)
        if state:
            h, c = state
        else:
            # Note: the LSTM state must have hidden_size features, not input_size,
            # and it must live on the same device as the model.
            h = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size
            ).to(device)
            c = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size
            ).to(device)

        sos_tokens = add_sos(tokens, sos_id)
        tokens_eos = add_eos(tokens, eos_id)
        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)

        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]

        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)

        x_tokens = x_tokens.to(torch.int64).to(device)
        y_tokens = y_tokens.to(torch.int64).to(device)
        sentence_lengths = sentence_lengths.to(torch.int64).to(device)

        embedding = self.input_embedding(x_tokens)

        # Note: We use batch_first=True
        rnn_out, states = self.rnn(embedding, (h, c))
        logits = self.output_linear(rnn_out)

        return logits, states
if __name__=="__main__":
LM = RnnLmModel(500, 2048, 2048, 3, True)
h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)
seq = [[0,1,2,3]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output1, state = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[0,1,2,3,4]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output2, _ = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[4]]
seq_lens = [len(s) for s in seq]
output3 = LM.score_token(seq, seq_lens, state)
print("Finished")