resolve conflicts

This commit is contained in:
marcoyang 2022-11-04 11:25:10 +08:00
parent 0df597291f
commit bdaeaae1ae

View File

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import k2 import k2
@ -729,7 +729,7 @@ class Hypothesis:
# timestamp[i] is the frame index after subsampling # timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded # on which ys[i] is decoded
timestamp: List[int] = None timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys # the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None lm_score: Optional[torch.Tensor] = None
@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
rnnlm: RnnLmModel, rnnlm: RnnLmModel,
rnnlm_scale: float, rnnlm_scale: float,
beam: int = 4, beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]: ) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion """Modified_beam_search + RNNLM shallow fusion
@ -1930,6 +1931,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, state=init_states,
lm_score=init_score.reshape(-1), lm_score=init_score.reshape(-1),
timestamp=[],
) )
) )
@ -1938,7 +1940,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
offset = 0 offset = 0
finalized_B = [] finalized_B = []
for batch_size in batch_size_list: for (t, batch_size) in enumerate(batch_size_list):
start = offset start = offset
end = offset + batch_size end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = encoder_out.data[start:end] # get batch
@ -2060,9 +2062,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
hyp_log_prob = topk_log_probs[k] # get score of current hyp hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k] new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id): if new_token not in (blank_id, unk_id):
ys.append(new_token) ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += ( hyp_log_prob += (
lm_score[new_token] * lm_scale lm_score[new_token] * lm_scale
) # add the lm score ) # add the lm score
@ -2075,7 +2079,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
count += 1 count += 1
new_hyp = Hypothesis( new_hyp = Hypothesis(
ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
) )
B[i].add(new_hyp) B[i].add(new_hyp)
@ -2083,9 +2091,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
best_hyps = [b.get_most_probable(length_norm=True) for b in B] best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps] sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = [] ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist() unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N): for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]]) ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
return ans if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)