mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
resolve conflicts
This commit is contained in:
parent 0df597291f
commit bdaeaae1ae
@@ -16,7 +16,7 @@
 # limitations under the License.

 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union

 import k2
@@ -729,7 +729,7 @@ class Hypothesis:

     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
-    timestamp: List[int] = None
+    timestamp: List[int] = field(default_factory=list)

     # the lm score for next token given the current ys
     lm_score: Optional[torch.Tensor] = None
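A note on the default above: dataclasses reject a bare mutable default such as [], and a None default would break the hyp.timestamp[:] copy used later in this diff, so field(default_factory=list) gives every Hypothesis instance its own fresh list. A minimal standalone sketch (the Hyp class below is illustrative, not the real Hypothesis):

# Minimal sketch, not icefall code: why field(default_factory=list) is used
# instead of None or [] as the default for a dataclass list field.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Hyp:
    ys: List[int] = field(default_factory=list)
    timestamp: List[int] = field(default_factory=list)  # one list per instance

a, b = Hyp(), Hyp()
a.timestamp.append(3)
assert b.timestamp == []  # lists are not shared between instances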
@@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
     rnnlm: RnnLmModel,
     rnnlm_scale: float,
     beam: int = 4,
+    return_timestamps: bool = False,
 ) -> List[List[int]]:
     """Modified_beam_search + RNNLM shallow fusion

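With return_timestamps added, the function can now return either token lists or a DecodingResults object, although the annotation in this hunk still reads List[List[int]]. A hedged, self-contained sketch of that flag-controlled return pattern (DecodingResultsSketch and decode are stand-ins invented here, not icefall code):

# Illustrative stand-in for the flag-controlled return type used in this commit.
from dataclasses import dataclass, field
from typing import List, Union

@dataclass
class DecodingResultsSketch:  # stand-in for icefall's DecodingResults
    tokens: List[List[int]] = field(default_factory=list)
    timestamps: List[List[int]] = field(default_factory=list)

def decode(return_timestamps: bool = False) -> Union[List[List[int]], DecodingResultsSketch]:
    ans = [[1, 2, 3]]
    ans_timestamps = [[0, 4, 9]]
    if not return_timestamps:
        return ans
    return DecodingResultsSketch(tokens=ans, timestamps=ans_timestamps)

print(decode())                        # [[1, 2, 3]]
print(decode(return_timestamps=True))  # DecodingResultsSketch(tokens=..., timestamps=...)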
@@ -1930,6 +1931,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states,
                 lm_score=init_score.reshape(-1),
+                timestamp=[],
             )
         )

@@ -1938,7 +1940,7 @@ def modified_beam_search_rnnlm_shallow_fusion(

     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]  # get batch
@@ -2060,9 +2062,11 @@ def modified_beam_search_rnnlm_shallow_fusion(

                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):

                     ys.append(new_token)
+                    new_timestamp.append(t)
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                     )  # add the lm score
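The hyp.timestamp[:] copy before the conditional append is what keeps each expanded hypothesis's timestamps independent of its parent; only non-blank, non-unk tokens get a frame index recorded. A toy illustration of that copy-then-append step (the token ids and frame values are made up):

# Toy illustration: each expansion gets its own timestamp list, so sibling
# hypotheses do not share (and corrupt) their parent's history.
blank_id, unk_id = 0, 2
parent_timestamp = [1, 5, 7]              # frames (after subsampling) of earlier tokens
for new_token, t in [(0, 9), (13, 9)]:    # a blank and a real token at frame 9
    new_timestamp = parent_timestamp[:]   # copy, as in hyp.timestamp[:]
    if new_token not in (blank_id, unk_id):
        new_timestamp.append(t)           # record the emitting frame
    print(new_token, new_timestamp)
# -> 0 [1, 5, 7]       (blank: no timestamp added)
# -> 13 [1, 5, 7, 9]   (real token emitted at frame 9)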
@@ -2075,7 +2079,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
                     count += 1

                 new_hyp = Hypothesis(
-                    ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    timestamp=new_timestamp,
                 )
                 B[i].add(new_hyp)

@@ -2083,9 +2091,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
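With return_timestamps=True the caller receives frame indices measured after subsampling, so converting them to seconds needs the model's subsampling factor and frame shift. A caller-side sketch (the factor of 4 and the 10 ms frame shift are assumptions about the frontend, not values taken from this commit):

# Hypothetical post-processing of the returned timestamps. The subsampling
# factor (4) and frame shift (0.01 s) are assumptions, not specified here.
SUBSAMPLING_FACTOR = 4
FRAME_SHIFT_S = 0.01

def frames_to_seconds(timestamps):
    """Map per-token frame indices (after subsampling) to seconds."""
    return [[t * SUBSAMPLING_FACTOR * FRAME_SHIFT_S for t in utt] for utt in timestamps]

print(frames_to_seconds([[2, 7, 11]]))  # -> roughly [[0.08, 0.28, 0.44]]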