marcoyang 2022-11-02 17:48:58 +08:00
parent babcfd4b68
commit 6c8d1f9ef5


@@ -17,16 +17,23 @@
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 import k2
 import sentencepiece as spm
 import torch
 from model import Transducer

+from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.rnn_lm.model import RnnLmModel
-from icefall.utils import add_eos, add_sos, get_texts
+from icefall.utils import (
+    DecodingResults,
+    add_eos,
+    add_sos,
+    get_texts,
+    get_texts_with_timestamp,
+)


 def fast_beam_search_one_best(
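Note: DecodingResults is the new return type used throughout this diff. Judging only from how it is constructed below (DecodingResults(tokens=..., timestamps=...)), it pairs token IDs with per-token frame indices roughly as in the following sketch; the authoritative definition lives in icefall/utils.py.

# Rough shape of DecodingResults as used in this file (sketch only; see
# icefall/utils.py for the real definition).
from dataclasses import dataclass
from typing import List

@dataclass
class DecodingResults:
    # tokens[i] is the list of decoded token IDs for the i-th utterance.
    tokens: List[List[int]]
    # timestamps[i][k] is the frame index (after subsampling) at which
    # tokens[i][k] was emitted.
    timestamps: List[List[int]]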
@@ -38,7 +45,8 @@ def fast_beam_search_one_best(
     max_states: int,
     max_contexts: int,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, and then
@@ -62,8 +70,12 @@ def fast_beam_search_one_best(
         Max contexts pre stream per frame.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -77,8 +89,11 @@ def fast_beam_search_one_best(
     )

     best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest_LG(
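For orientation, a caller-side sketch of the new flag. Only return_timestamps is introduced by this commit; the other parameter names follow the existing icefall API, and the setup of model, decoding_graph, encoder_out and encoder_out_lens is assumed to be done as in the recipe's decode.py. The beam/max_states/max_contexts values are illustrative.

# Sketch of calling fast_beam_search_one_best with timestamps enabled.
# `model`, `decoding_graph`, `encoder_out`, `encoder_out_lens` are assumed
# to be prepared as in the recipe's decode.py.
res = fast_beam_search_one_best(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4,
    max_states=64,
    max_contexts=8,
    return_timestamps=True,
)
# res is a DecodingResults; res.tokens[i] and res.timestamps[i] are aligned
# token IDs and emitting frame indices for the i-th utterance.
for tokens, frames in zip(res.tokens, res.timestamps):
    print(tokens, frames)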
@@ -93,7 +108,8 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     The process to get the results is:
@@ -130,8 +146,12 @@ def fast_beam_search_nbest_LG(
         single precision.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -196,9 +216,10 @@ def fast_beam_search_nbest_LG(
     best_hyp_indexes = ragged_tot_scores.argmax()
     best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)

-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest(
@@ -213,7 +234,8 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     The process to get the results is:
@@ -250,8 +272,12 @@ def fast_beam_search_nbest(
         single precision.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -280,9 +306,10 @@ def fast_beam_search_nbest(
     best_path = k2.index_fsa(nbest.fsa, max_indexes)

-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest_oracle(
@@ -298,7 +325,8 @@ def fast_beam_search_nbest_oracle(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, and then
@@ -339,8 +367,12 @@ def fast_beam_search_nbest_oracle(
         yields more unique paths.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -379,8 +411,10 @@ def fast_beam_search_nbest_oracle(
     best_path = k2.index_fsa(nbest.fsa, max_indexes)

-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search(
@@ -470,8 +504,11 @@ def fast_beam_search(


 def greedy_search(
-    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
-) -> List[int]:
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    max_sym_per_frame: int,
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """Greedy search for a single utterance.

     Args:
       model:
@@ -481,8 +518,12 @@ def greedy_search(
       max_sym_per_frame:
         Maximum number of symbols per frame. If it is set to 0, the WER
         would be 100%.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3

@@ -508,6 +549,10 @@ def greedy_search(
     t = 0
     hyp = [blank_id] * context_size

+    # timestamp[i] is the frame index after subsampling
+    # on which hyp[i] is decoded
+    timestamp = []
+
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000
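The timestamps collected below are encoder frame indices after subsampling. A hedged sketch of turning them into seconds; the 10 ms feature frame shift and the 4x subsampling factor are recipe-dependent assumptions, not part of this diff.

# Illustrative only: convert timestamps (subsampled frame indices) to seconds.
def frames_to_seconds(frames, frame_shift_s=0.01, subsampling_factor=4):
    # frame_shift_s and subsampling_factor must match the recipe's feature
    # extraction and encoder; the values here are typical defaults.
    return [f * frame_shift_s * subsampling_factor for f in frames]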
@@ -534,6 +579,7 @@ def greedy_search(
         y = logits.argmax().item()
         if y not in (blank_id, unk_id):
             hyp.append(y)
+            timestamp.append(t)
             decoder_input = torch.tensor(
                 [hyp[-context_size:]], device=device
             ).reshape(1, context_size)
@@ -548,14 +594,21 @@ def greedy_search(
             t += 1
     hyp = hyp[context_size:]  # remove blanks

-    return hyp
+    if not return_timestamps:
+        return hyp
+    else:
+        return DecodingResults(
+            tokens=[hyp],
+            timestamps=[timestamp],
+        )


 def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

     Args:
       model:
@@ -565,9 +618,12 @@ def greedy_search_batch(
       encoder_out_lens:
         A 1-D tensor of shape (N,), containing number of valid frames in
         encoder_out before padding.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs containing the decoded results.
-      len(ans) equals to encoder_out.size(0).
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -592,6 +648,10 @@ def greedy_search_batch(
     hyps = [[blank_id] * context_size for _ in range(N)]

+    # timestamp[n][i] is the frame index after subsampling
+    # on which hyp[n][i] is decoded
+    timestamps = [[] for _ in range(N)]
+
     decoder_input = torch.tensor(
         hyps,
         device=device,
@@ -605,7 +665,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

     offset = 0
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -627,6 +687,7 @@ def greedy_search_batch(
         for i, v in enumerate(y):
             if v not in (blank_id, unk_id):
                 hyps[i].append(v)
+                timestamps[i].append(t)
                 emitted = True
         if emitted:
             # update decoder output
@@ -641,11 +702,19 @@ def greedy_search_batch(
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(timestamps[unsorted_indices[i]])

-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )


 @dataclass
@@ -657,9 +726,12 @@ class Hypothesis:
     # The log prob of ys.
     # It contains only one entry.
     log_prob: torch.Tensor

-    state: Optional=None
-    lm_score: Optional=None
+    # timestamp[i] is the frame index after subsampling
+    # on which ys[i] is decoded
+    timestamp: List[int]
+
+    state_cost: Optional[NgramLmStateCost] = None

     @property
     def key(self) -> str:
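Since timestamp is added as a field without a default value, every Hypothesis built in this file now passes it explicitly, typically an empty list for the initial blank-only hypothesis. A minimal sketch of the construction pattern used below; blank_id and context_size values are illustrative.

import torch

blank_id, context_size = 0, 2  # illustrative values
hyp = Hypothesis(
    ys=[blank_id] * context_size,
    log_prob=torch.zeros(1, dtype=torch.float32),
    timestamp=[],  # later filled with the emitting frame index of each token
)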
@@ -808,7 +880,8 @@ def modified_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

     Args:
@@ -823,9 +896,12 @@ def modified_beam_search(
         Number of active paths during the beam search.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs. ans[i] is the decoding results
-      for the i-th utterance.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -853,6 +929,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                timestamp=[],
             )
         )
@@ -860,7 +937,7 @@ def modified_beam_search(
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -938,30 +1015,44 @@ def modified_beam_search(
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
+                    new_timestamp.append(t)

                 new_log_prob = topk_log_probs[k]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                new_hyp = Hypothesis(
+                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                )
                 B[i].add(new_hyp)

     B = B + finalized_B
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )


 def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     It decodes only one utterance at a time. We keep it only for reference.
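Because these functions now return either a plain list of token-ID lists or a DecodingResults depending on the flag, calling code that wants to handle both forms can normalize the result. A hypothetical helper, not part of this commit:

from typing import List, Optional, Tuple, Union

def split_results(
    res: Union[List[List[int]], "DecodingResults"]
) -> Tuple[List[List[int]], Optional[List[List[int]]]]:
    # Returns (tokens, timestamps); timestamps is None when the plain list
    # form was returned (return_timestamps=False).
    if isinstance(res, list):
        return res, None
    return res.tokens, res.timestamps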
@@ -976,8 +1067,13 @@ def _deprecated_modified_beam_search(
         A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
       beam:
         Beam size.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
@@ -997,6 +1093,7 @@ def _deprecated_modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            timestamp=[],
         )
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -1055,17 +1152,24 @@ def _deprecated_modified_beam_search(
         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
+            new_timestamp = hyp.timestamp[:]
             new_token = topk_token_indexes[i]
             if new_token not in (blank_id, unk_id):
                 new_ys.append(new_token)
+                new_timestamp.append(t)

             new_log_prob = topk_log_probs[i]
-            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            new_hyp = Hypothesis(
+                ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+            )
             B.add(new_hyp)

     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

-    return ys
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])


 def beam_search(
@@ -1073,7 +1177,8 @@ def beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -1088,8 +1193,13 @@ def beam_search(
         Beam size.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
@@ -1116,7 +1226,7 @@ def beam_search(
     t = 0
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))

     max_sym_per_utt = 20000
@@ -1177,7 +1287,13 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob

             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(
+                Hypothesis(
+                    ys=y_star.ys[:],
+                    log_prob=new_y_star_log_prob,
+                    timestamp=y_star.timestamp[:],
+                )
+            )

             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -1186,7 +1302,14 @@ def beam_search(
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+                new_timestamp = y_star.timestamp + [t]
+                A.add(
+                    Hypothesis(
+                        ys=new_ys,
+                        log_prob=new_log_prob,
+                        timestamp=new_timestamp,
+                    )
+                )

             # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
@@ -1202,7 +1325,11 @@ def beam_search(
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

-    return ys
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])


 def fast_beam_search_with_nbest_rescoring(
@@ -1222,7 +1349,8 @@ def fast_beam_search_with_nbest_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, num_path are selected
     and rescored using a given language model. The shortest path within the
@@ -1264,10 +1392,13 @@ def fast_beam_search_with_nbest_rescoring(
         yields more unique paths.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
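The rescoring variants return a dict keyed by the LM scale used. A sketch of consuming it; `results` is assumed to be the value returned by one of these functions when called with return_timestamps=True.

# Illustrative consumption of the returned dict.
for key, res in results.items():
    # key looks like "ngram_lm_scale_0.1"; res is a DecodingResults here
    # (it would be a plain list of token-ID lists with return_timestamps=False).
    print(key, res.tokens[0], res.timestamps[0])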
@@ -1345,16 +1476,18 @@ def fast_beam_search_with_nbest_rescoring(
         log_semiring=False,
     )

-    ans: Dict[str, List[List[int]]] = {}
+    ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
     for s in ngram_lm_scale_list:
         key = f"ngram_lm_scale_{s}"
         tot_scores = am_scores.values + s * ngram_lm_scores
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)
-        ans[key] = hyps
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)

     return ans
@@ -1378,7 +1511,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, num_path are selected
     and rescored using a given language model and a rnn-lm.
@@ -1424,10 +1558,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         yields more unique paths.
       temperature:
         Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
@@ -1539,12 +1676,185 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)
-        ans[key] = hyps
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)

     return ans
+
+
+def modified_beam_search_ngram_rescoring(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ngram_lm: NgramLm,
+    ngram_lm_scale: float,
+    beam: int = 4,
+    temperature: float = 1.0,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
+      beam:
+        Number of active paths during the beam search.
+      temperature:
+        Softmax temperature.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+    lm_scale = ngram_lm_scale
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state_cost=NgramLmStateCost(ngram_lm),
+            )
+        )
+
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [
+                hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
+                for hyps in A
+                for hyp in hyps
+            ]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+        log_probs = (logits / temperature).log_softmax(
+            dim=-1
+        )  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    new_ys.append(new_token)
+                    state_cost = hyp.state_cost.forward_one_step(new_token)
+                else:
+                    state_cost = hyp.state_cost
+
+                # We only keep AM scores in new_hyp.log_prob
+                new_log_prob = (
+                    topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
+                )
+
+                new_hyp = Hypothesis(
+                    ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
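In the new modified_beam_search_ngram_rescoring, hyp.log_prob deliberately stores only the transducer (AM) score: the n-gram LM contribution is added when forming ys_log_probs for ranking and subtracted again before storing new_log_prob, so the LM score is re-applied from the updated LM state on the next step rather than accumulated twice. A toy illustration of that bookkeeping (numbers made up):

am_score = -3.2   # accumulated transducer score kept in hyp.log_prob
lm_score = -1.5   # hyp.state_cost.lm_score for the current prefix
lm_scale = 0.4

ranking_score = am_score + lm_scale * lm_score  # used to pick the top-k tokens
# Before the hypothesis is stored, the LM part is stripped off again so that
# log_prob stays a pure AM score:
stored_log_prob = ranking_score - lm_scale * lm_score
assert abs(stored_log_prob - am_score) < 1e-9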


 def modified_beam_search_rnnlm_shallow_fusion(
     model: Transducer,
     encoder_out: torch.Tensor,
@@ -1592,7 +1902,6 @@ def modified_beam_search_rnnlm_shallow_fusion(
     blank_id = model.decoder.blank_id
     sos_id = sp.piece_to_id("<sos/eos>")
-    eos_id = sp.piece_to_id("<sos/eos>")
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
     device = next(model.parameters()).device
@@ -1613,7 +1922,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states,
-                lm_score=init_score.reshape(-1)
+                lm_score=init_score.reshape(-1),
             )
         )
@@ -1625,7 +1934,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
     for batch_size in batch_size_list:
         start = offset
         end = offset + batch_size
-        current_encoder_out = encoder_out.data[start:end] # get batch
+        current_encoder_out = encoder_out.data[start:end]  # get batch
         current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
         offset = end
@@ -1665,9 +1974,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
         logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

-        log_probs = logits.log_softmax(
-            dim=-1
-        )  # (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

         log_probs.add_(ys_log_probs)
@@ -1683,7 +1990,6 @@ def modified_beam_search_rnnlm_shallow_fusion(
             shape=log_probs_shape, value=log_probs
         )

-        # for all hyps with a non-blank new token, score it
         token_list = []
         hs = []
@@ -1708,13 +2014,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
                 cs.append(hyp.state[1])

         # forward RNNLM to get new states and scores
         if len(token_list) != 0:
-            tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1)
+            tokens_to_score = (
+                torch.tensor(token_list)
+                .to(torch.int64)
+                .to(device)
+                .reshape(-1, 1)
+            )
             hs = torch.cat(hs, dim=1).to(device)
             cs = torch.cat(cs, dim=1).to(device)
-            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs))
+            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))

         count = 0  # index, used to locate score and lm states
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -1742,14 +2053,14 @@ def modified_beam_search_rnnlm_shallow_fusion(
                     )  # add the lm score

                     lm_score = scores[count]
-                    state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1))
+                    state = (
+                        lm_states[0][:, count, :].unsqueeze(1),
+                        lm_states[1][:, count, :].unsqueeze(1),
+                    )
                     count += 1

                 new_hyp = Hypothesis(
-                    ys=ys,
-                    log_prob=hyp_log_prob,
-                    state=state,
-                    lm_score=lm_score
+                    ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
                 )
                 B[i].add(new_hyp)