mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
resolve conflicts
This commit is contained in:
parent
0df597291f
commit
bdaeaae1ae
@ -16,7 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -729,7 +729,7 @@ class Hypothesis:
|
|||||||
|
|
||||||
# timestamp[i] is the frame index after subsampling
|
# timestamp[i] is the frame index after subsampling
|
||||||
# on which ys[i] is decoded
|
# on which ys[i] is decoded
|
||||||
timestamp: List[int] = None
|
timestamp: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
# the lm score for next token given the current ys
|
# the lm score for next token given the current ys
|
||||||
lm_score: Optional[torch.Tensor] = None
|
lm_score: Optional[torch.Tensor] = None
|
||||||
@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
rnnlm: RnnLmModel,
|
rnnlm: RnnLmModel,
|
||||||
rnnlm_scale: float,
|
rnnlm_scale: float,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
return_timestamps: bool = False,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Modified_beam_search + RNNLM shallow fusion
|
"""Modified_beam_search + RNNLM shallow fusion
|
||||||
|
|
||||||
@ -1930,6 +1931,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
state=init_states,
|
state=init_states,
|
||||||
lm_score=init_score.reshape(-1),
|
lm_score=init_score.reshape(-1),
|
||||||
|
timestamp=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1938,7 +1940,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for batch_size in batch_size_list:
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end] # get batch
|
current_encoder_out = encoder_out.data[start:end] # get batch
|
||||||
@ -2060,9 +2062,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
|
|
||||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
ys.append(new_token)
|
ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
hyp_log_prob += (
|
hyp_log_prob += (
|
||||||
lm_score[new_token] * lm_scale
|
lm_score[new_token] * lm_scale
|
||||||
) # add the lm score
|
) # add the lm score
|
||||||
@ -2075,7 +2079,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
new_hyp = Hypothesis(
|
new_hyp = Hypothesis(
|
||||||
ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
|
ys=ys,
|
||||||
|
log_prob=hyp_log_prob,
|
||||||
|
state=state,
|
||||||
|
lm_score=lm_score,
|
||||||
|
timestampe=new_timestamp,
|
||||||
)
|
)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
@ -2083,9 +2091,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||||
ans = []
|
ans = []
|
||||||
|
ans_timestamps = []
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||||
|
|
||||||
|
if not return_timestamps:
|
||||||
return ans
|
return ans
|
||||||
|
else:
|
||||||
|
return DecodingResults(
|
||||||
|
tokens=ans,
|
||||||
|
timestamps=ans_timestamps,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user