resolve conflicts

This commit is contained in:
marcoyang 2022-11-04 11:25:10 +08:00
parent 0df597291f
commit bdaeaae1ae

View File

@ -16,7 +16,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import k2
@ -729,7 +729,7 @@ class Hypothesis:
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = None
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion
@ -1930,6 +1931,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
@ -1938,7 +1940,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
offset = 0
finalized_B = []
for batch_size in batch_size_list:
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
@ -2060,9 +2062,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += (
lm_score[new_token] * lm_scale
) # add the lm score
@ -2075,7 +2079,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
count += 1
new_hyp = Hypothesis(
ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestampe=new_timestamp,
)
B[i].add(new_hyp)
@ -2083,9 +2091,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
return ans
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)