This commit is contained in:
marcoyang 2022-11-02 17:48:58 +08:00
parent babcfd4b68
commit 6c8d1f9ef5


@ -17,16 +17,23 @@
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import k2
import sentencepiece as spm
import torch
from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import add_eos, add_sos, get_texts
from icefall.utils import (
DecodingResults,
add_eos,
add_sos,
get_texts,
get_texts_with_timestamp,
)
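For orientation: DecodingResults is constructed further down with tokens= and timestamps= keyword arguments. A minimal sketch of the shape this file assumes (the real definition lives in icefall.utils and may differ in details):

from dataclasses import dataclass
from typing import List

@dataclass
class DecodingResultsSketch:
    # tokens[i] holds the decoded token IDs of the i-th utterance.
    tokens: List[List[int]]
    # timestamps[i][k] is the frame index (after subsampling) on which
    # tokens[i][k] is decoded.
    timestamps: List[List[int]]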
def fast_beam_search_one_best(
@ -38,7 +45,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
@ -62,8 +70,12 @@ def fast_beam_search_one_best(
Max contexts per stream per frame.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
lattice = fast_beam_search(
model=model,
@ -77,8 +89,11 @@ def fast_beam_search_one_best(
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
if not return_timestamps:
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
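A hedged usage sketch of the new flag. The argument names mirror the usual icefall signature; model, decoding_graph, encoder_out and encoder_out_lens stand in for whatever the calling decode script already builds, and the beam settings are illustrative:

res = fast_beam_search_one_best(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=20.0,
    max_contexts=8,
    max_states=64,
    return_timestamps=True,
)
# res.tokens[i][k] was emitted on subsampled frame res.timestamps[i][k].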
def fast_beam_search_nbest_LG(
@ -93,7 +108,8 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
@ -130,8 +146,12 @@ def fast_beam_search_nbest_LG(
single precision.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
lattice = fast_beam_search(
model=model,
@ -196,9 +216,10 @@ def fast_beam_search_nbest_LG(
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
if not return_timestamps:
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
def fast_beam_search_nbest(
@ -213,7 +234,8 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
@ -250,8 +272,12 @@ def fast_beam_search_nbest(
single precision.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
lattice = fast_beam_search(
model=model,
@ -280,9 +306,10 @@ def fast_beam_search_nbest(
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
if not return_timestamps:
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
def fast_beam_search_nbest_oracle(
@ -298,7 +325,8 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
@ -339,8 +367,12 @@ def fast_beam_search_nbest_oracle(
yields more unique paths.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
lattice = fast_beam_search(
model=model,
@ -379,8 +411,10 @@ def fast_beam_search_nbest_oracle(
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
if not return_timestamps:
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
def fast_beam_search(
@ -470,8 +504,11 @@ def fast_beam_search(
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
model: Transducer,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
Args:
model:
@ -481,8 +518,12 @@ def greedy_search(
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
assert encoder_out.ndim == 3
@ -508,6 +549,10 @@ def greedy_search(
t = 0
hyp = [blank_id] * context_size
# timestamp[i] is the frame index after subsampling
# on which hyp[i] is decoded
timestamp = []
# Maximum symbols per utterance.
max_sym_per_utt = 1000
@ -534,6 +579,7 @@ def greedy_search(
y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
timestamp.append(t)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
).reshape(1, context_size)
@ -548,14 +594,21 @@ def greedy_search(
t += 1
hyp = hyp[context_size:] # remove blanks
return hyp
if not return_timestamps:
return hyp
else:
return DecodingResults(
tokens=[hyp],
timestamps=[timestamp],
)
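Callers now have to branch on the flag. A small illustrative sketch; the variable names are assumed, not part of the diff:

out = greedy_search(
    model=model,
    encoder_out=encoder_out,  # (1, T, C): a single utterance
    max_sym_per_frame=1,
    return_timestamps=True,
)
# With return_timestamps=True a DecodingResults comes back; the single
# utterance sits at index 0 of both parallel lists.
hyp_tokens = out.tokens[0]
hyp_frames = out.timestamps[0]
assert len(hyp_tokens) == len(hyp_frames)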
def greedy_search_batch(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
@ -565,9 +618,12 @@ def greedy_search_batch(
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
return_timestamps:
Whether to return timestamps.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
@ -592,6 +648,10 @@ def greedy_search_batch(
hyps = [[blank_id] * context_size for _ in range(N)]
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
@ -605,7 +665,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for batch_size in batch_size_list:
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -627,6 +687,7 @@ def greedy_search_batch(
for i, v in enumerate(y):
if v not in (blank_id, unk_id):
hyps[i].append(v)
timestamps[i].append(t)
emitted = True
if emitted:
# update decoder output
@ -641,11 +702,19 @@ def greedy_search_batch(
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
return ans
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
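The recorded timestamps are frame indices after the encoder's subsampling. A hedged helper for turning them into seconds; the subsampling factor of 4 and the 10 ms feature frame shift are assumptions that must match the actual model and features:

from typing import List

def frames_to_seconds(
    frames: List[int],
    subsampling_factor: int = 4,  # assumed encoder subsampling factor
    frame_shift_ms: float = 10.0,  # assumed feature frame shift
) -> List[float]:
    """Convert subsampled frame indices to times in seconds."""
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]

# e.g. frames_to_seconds([0, 3, 7]) == [0.0, 0.12, 0.28]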
@dataclass
@ -657,9 +726,12 @@ class Hypothesis:
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
state: Optional=None
lm_score: Optional=None
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int]
state_cost: Optional[NgramLmStateCost] = None
@property
def key(self) -> str:
@ -808,7 +880,8 @@ def modified_beam_search(
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]:
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
@ -823,9 +896,12 @@ def modified_beam_search(
Number of active paths during the beam search.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
@ -843,7 +919,7 @@ def modified_beam_search(
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
@ -853,6 +929,7 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
@ -860,7 +937,7 @@ def modified_beam_search(
offset = 0
finalized_B = []
for batch_size in batch_size_list:
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -938,30 +1015,44 @@ def modified_beam_search(
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
return ans
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
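To inspect the per-token alignment for one utterance, the two parallel lists can simply be zipped. A sketch, assuming a SentencePiece processor sp as used elsewhere in this file and the usual call-site variables:

res = modified_beam_search(
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4,
    return_timestamps=True,
)
for token_id, frame in zip(res.tokens[0], res.timestamps[0]):
    # id_to_piece maps a token ID back to its SentencePiece piece.
    print(sp.id_to_piece(token_id), frame)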
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
It decodes only one utterance at a time. We keep it only for reference.
@ -976,8 +1067,13 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
assert encoder_out.ndim == 3
@ -997,6 +1093,7 @@ def _deprecated_modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(encoder_out)
@ -1055,17 +1152,24 @@ def _deprecated_modified_beam_search(
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_timestamp = hyp.timestamp[:]
new_token = topk_token_indexes[i]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
if not return_timestamps:
return ys
else:
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
def beam_search(
@ -1073,7 +1177,8 @@ def beam_search(
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[int]:
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -1088,8 +1193,13 @@ def beam_search(
Beam size.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result.
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing the
decoded result and the corresponding timestamps.
"""
assert encoder_out.ndim == 3
@ -1116,7 +1226,7 @@ def beam_search(
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
max_sym_per_utt = 20000
@ -1177,7 +1287,13 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
B.add(
Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
timestamp=y_star.timestamp[:],
)
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@ -1186,7 +1302,14 @@ def beam_search(
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
new_timestamp = y_star.timestamp + [t]
A.add(
Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
)
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
@ -1202,7 +1325,11 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
if not return_timestamps:
return ys
else:
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
def fast_beam_search_with_nbest_rescoring(
@ -1222,7 +1349,8 @@ def fast_beam_search_with_nbest_rescoring(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
return_timestamps: bool = False,
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search; num_paths paths are
selected and rescored using a given language model. The shortest path within the
@ -1264,10 +1392,13 @@ def fast_beam_search_with_nbest_rescoring(
yields more unique paths.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, i.e., 0.1.
'ngram_lm_scale_xx' and the value is the decoded results
optionally with timestamps. `xx` is the ngram LM scale value
used during decoding, i.e., 0.1.
"""
lattice = fast_beam_search(
model=model,
@ -1345,16 +1476,18 @@ def fast_beam_search_with_nbest_rescoring(
log_semiring=False,
)
ans: Dict[str, List[List[int]]] = {}
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
for s in ngram_lm_scale_list:
key = f"ngram_lm_scale_{s}"
tot_scores = am_scores.values + s * ngram_lm_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
if not return_timestamps:
ans[key] = get_texts(best_path)
else:
ans[key] = get_texts_with_timestamp(best_path)
return ans
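With return_timestamps=True each value in the returned dict is a DecodingResults rather than a plain list of token ID lists. A hedged sketch of consuming such a dict; the helper name is made up and the tokens field is assumed from how DecodingResults is constructed above:

from typing import Dict

from icefall.utils import DecodingResults

def summarize_rescoring(results: Dict[str, DecodingResults]) -> None:
    # Keys have the form "ngram_lm_scale_0.1".
    for key, value in results.items():
        scale = float(key.rsplit("_", 1)[-1])
        print(f"ngram LM scale {scale}: {len(value.tokens)} utterances decoded")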
@ -1378,7 +1511,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
return_timestamps: bool = False,
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search; num_paths paths are
selected and rescored using a given language model and an RNN LM.
@ -1424,10 +1558,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
yields more unique paths.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, i.e., 0.1.
'ngram_lm_scale_xx' and the value is the decoded results
optionally with timestamps. `xx` is the ngram LM scale value
used during decoding, i.e., 0.1.
"""
lattice = fast_beam_search(
model=model,
@ -1539,12 +1676,185 @@ def fast_beam_search_with_nbest_rnn_rescoring(
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
if not return_timestamps:
ans[key] = get_texts(best_path)
else:
ans[key] = get_texts_with_timestamp(best_path)
return ans
def modified_beam_search_ngram_rescoring(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ngram_lm: NgramLm,
ngram_lm_scale: float,
beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
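ngram_lm:
The n-gram language model used for rescoring.
ngram_lm_scale:
Scale applied to the n-gram LM scores when they are combined
with the acoustic scores.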
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding result
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
lm_scale = ngram_lm_scale
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state_cost=NgramLmStateCost(ngram_lm),
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[
hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
for hyps in A
for hyp in hyps
]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
state_cost = hyp.state_cost.forward_one_step(new_token)
else:
state_cost = hyp.state_cost
# We only keep AM scores in new_hyp.log_prob
new_log_prob = (
topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
)
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
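The score bookkeeping above is easy to miss: hyp.log_prob holds only the acoustic score, the scaled n-gram score is added just for ranking and stripped off again before the new hypothesis is stored. A toy sketch of the same arithmetic with made-up numbers:

am_log_prob = -3.2  # hyp.log_prob: acoustic score only
ngram_lm_score = -1.5  # hyp.state_cost.lm_score
lm_scale = 0.4  # ngram_lm_scale
token_log_prob = -0.7  # joiner log-prob of the newly expanded token

# ys_log_probs folds in the scaled LM score, so topk ranks on AM + LM:
ranking_score = (am_log_prob + lm_scale * ngram_lm_score) + token_log_prob

# new_log_prob strips the LM part again, keeping only AM in the hypothesis:
new_log_prob = ranking_score - lm_scale * ngram_lm_score
assert abs(new_log_prob - (am_log_prob + token_log_prob)) < 1e-9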
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
@ -1559,18 +1869,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
@ -1582,7 +1892,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
@ -1592,20 +1902,19 @@ def modified_beam_search_rnnlm_shallow_fusion(
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
eos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
@ -1613,19 +1922,19 @@ def modified_beam_search_rnnlm_shallow_fusion(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1)
lm_score=init_score.reshape(-1),
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
@ -1637,44 +1946,42 @@ def modified_beam_search_rnnlm_shallow_fusion(
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(
dim=-1
) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
@ -1682,7 +1989,6 @@ def modified_beam_search_rnnlm_shallow_fusion(
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
# for all hyps with a non-blank new token, score it
token_list = []
@ -1698,7 +2004,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
@ -1708,13 +2014,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1)
tokens_to_score = (
torch.tensor(token_list)
.to(torch.int64)
.to(device)
.reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs))
count = 0 # index, used to locate score and lm states
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@ -1722,36 +2033,36 @@ def modified_beam_search_rnnlm_shallow_fusion(
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
hyp_log_prob += (
lm_score[new_token] * lm_scale
) # add the lm score
lm_score = scores[count]
state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1))
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score
ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
@ -1762,4 +2073,4 @@ def modified_beam_search_rnnlm_shallow_fusion(
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
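For contrast with the n-gram rescoring above, shallow fusion keeps the language-model contribution inside hyp.log_prob: every accepted token adds lm_scale * lm_score[new_token] on top of the acoustic score. A toy sketch with made-up numbers mirroring the update in the loop above:

hyp_log_prob = -4.1  # current hypothesis score (AM plus LM so far)
token_am_log_prob = -0.9  # joiner log-prob of the new token
rnnlm_log_prob = -1.3  # lm_score[new_token] from the RNNLM
lm_scale = 0.3  # rnnlm_scale

# topk already folded the token's AM log-prob into the hypothesis score;
# shallow fusion then adds the scaled RNNLM score for the same token.
new_hyp_log_prob = (hyp_log_prob + token_am_log_prob) + lm_scale * rnnlm_log_prob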