mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
fix bugs
This commit is contained in:
parent
4eb356ce49
commit
07587d106a
@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@ -750,6 +751,9 @@ class Hypothesis:
|
||||
"""Return a string representation of self.ys"""
|
||||
return "_".join(map(str, self.ys))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
@ -887,6 +891,7 @@ def modified_beam_search(
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
num_context_history: int = 1,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
@ -938,7 +943,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
context_state=ContextState(state_id=0),
|
||||
context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
@ -961,6 +966,7 @@ def modified_beam_search(
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
@ -1018,30 +1024,24 @@ def modified_beam_search(
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
new_context_state = None
|
||||
context_score = 0
|
||||
new_context_state = None if context_graph is None else hyp.context_state.clone()
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
if context_graph is not None:
|
||||
new_context_state = context_graph.get_next_state(
|
||||
hyp.context_state.state_id, new_token
|
||||
)
|
||||
new_log_prob = topk_log_probs[k] + (
|
||||
0
|
||||
if new_context_state is None
|
||||
else new_context_state.score
|
||||
)
|
||||
context_score, new_context_state = hyp.context_state.forward_one_step(new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k] + context_score
|
||||
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
context_state=hyp.context_state
|
||||
if new_context_state is None
|
||||
else new_context_state,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
@ -1053,20 +1053,15 @@ def modified_beam_search(
|
||||
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||
for i, hyps in enumerate(B):
|
||||
for hyp in list(hyps):
|
||||
if hyp.context_state.state_id != 0:
|
||||
new_context_state = context_graph.get_next_state(
|
||||
hyp.context_state.state_id, 0
|
||||
context_score, new_context_state = hyp.context_state.finalize()
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
log_prob=hyp.log_prob + context_score,
|
||||
timestamp=hyp.timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
log_prob=hyp.log_prob + new_context_state.score,
|
||||
timestamp=hyp.timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
)
|
||||
else:
|
||||
finalized_B[i].add(hyp)
|
||||
)
|
||||
B = finalized_B
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
@ -131,6 +131,8 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifst
|
||||
import graphviz
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -363,6 +365,13 @@ def get_parser():
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-context-history",
|
||||
type=int,
|
||||
default=1,
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
@ -511,6 +520,7 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
num_context_history=params.num_context_history,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
@ -565,7 +575,10 @@ def decode_one_batch(
|
||||
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key += f"-context-score-{params.context_score}"
|
||||
key += f"-num-context-history-{params.num_context_history}"
|
||||
return {key: (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -614,7 +627,7 @@ def decode_dataset(
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
log_interval = 1
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -769,6 +782,8 @@ def main():
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
params.suffix += f"-num-context-history-{params.num_context_history}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
@ -381,6 +381,7 @@ class LibriSpeechAsrDataModule:
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
num_buckets=2,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
@ -452,6 +453,13 @@ class LibriSpeechAsrDataModule:
|
||||
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_book_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-books cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_book2_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-books2 cuts")
|
||||
|
||||
@ -287,6 +287,13 @@ def get_parser():
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-context-history",
|
||||
type=int,
|
||||
default=1,
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
@ -389,6 +396,7 @@ def decode_one_batch(
|
||||
beam=params.beam_size,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_graph=context_graph,
|
||||
num_context_history=params.num_context_history,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -429,7 +437,7 @@ def decode_one_batch(
|
||||
}
|
||||
else:
|
||||
return {
|
||||
f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
|
||||
f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
|
||||
}
|
||||
|
||||
|
||||
@ -568,6 +576,7 @@ def main():
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
params.suffix += f"-num-context-history-{params.num_context_history}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
@ -14,11 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from heapq import heappush, heappop
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
import argparse
|
||||
import k2
|
||||
import kaldifst
|
||||
@ -27,17 +26,13 @@ import sentencepiece as spm
|
||||
from icefall.utils import is_module_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextState:
|
||||
state_id: int = 0
|
||||
score: float = 0.0
|
||||
|
||||
|
||||
class ContextGraph:
|
||||
def __init__(self, context_score: float = 1):
|
||||
self.context_score = context_score
|
||||
|
||||
def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
|
||||
def build_context_graph_char(
|
||||
self, contexts: List[str], token_table: k2.SymbolTable
|
||||
):
|
||||
"""Convert a list of texts to a list-of-list of token IDs.
|
||||
|
||||
Args:
|
||||
@ -56,7 +51,7 @@ class ContextGraph:
|
||||
whitespace = re.compile(r"([ \t])")
|
||||
for text in contexts:
|
||||
text = re.sub(whitespace, "", text)
|
||||
sub_ids : List[int] = []
|
||||
sub_ids: List[int] = []
|
||||
skip = False
|
||||
for txt in text:
|
||||
if txt not in token_table:
|
||||
@ -69,7 +64,9 @@ class ContextGraph:
|
||||
ids.append(sub_ids)
|
||||
self.build_context_graph(ids)
|
||||
|
||||
def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
|
||||
def build_context_graph_bpe(
|
||||
self, contexts: List[str], sp: spm.SentencePieceProcessor
|
||||
):
|
||||
contexts_bpe = sp.encode(contexts)
|
||||
self.build_context_graph(contexts_bpe)
|
||||
|
||||
@ -80,7 +77,7 @@ class ContextGraph:
|
||||
) # 1st state will be state 0 (returned by add_state)
|
||||
assert start_state == 0, start_state
|
||||
graph.start = 0 # set the start state to 0
|
||||
graph.set_final(start_state, weight=0) # weight is in log space
|
||||
graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
|
||||
|
||||
for tokens in token_ids:
|
||||
prev_state = start_state
|
||||
@ -111,22 +108,128 @@ class ContextGraph:
|
||||
prev_state = next_state
|
||||
backoff_score += score
|
||||
self.graph = kaldifst.determinize(graph)
|
||||
kaldifst.arcsort(self.graph)
|
||||
|
||||
def get_next_state(self, state_id: int, label: int) -> ContextState:
|
||||
next_state = 0
|
||||
score = 0
|
||||
for arc in kaldifst.ArcIterator(self.graph, state_id):
|
||||
if arc.ilabel == 0:
|
||||
score = arc.weight.value
|
||||
elif arc.ilabel == label:
|
||||
next_state = arc.nextstate
|
||||
score = arc.weight.value
|
||||
break
|
||||
return ContextState(
|
||||
state_id=next_state,
|
||||
score=score,
|
||||
def is_final_state(self, state_id: int) -> bool:
|
||||
return self.graph.final(state_id) == kaldifst.TropicalWeight.one
|
||||
|
||||
|
||||
def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
|
||||
arc_iter = kaldifst.ArcIterator(self.graph, state_id)
|
||||
num_arcs = self.graph.num_arcs(state_id)
|
||||
|
||||
# The LM is arc sorted by ilabel, so we use binary search below.
|
||||
left = 0
|
||||
right = num_arcs - 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
arc_iter.seek(mid)
|
||||
arc = arc_iter.value
|
||||
if arc.ilabel < label:
|
||||
left = mid + 1
|
||||
elif arc.ilabel > label:
|
||||
right = mid - 1
|
||||
else:
|
||||
return (arc.nextstate, arc.weight.value, True)
|
||||
|
||||
# Backoff to state 0 with the score on epsilon arc (ilabel == 0)
|
||||
arc_iter.seek(0)
|
||||
arc = arc_iter.value
|
||||
if arc.ilabel == 0:
|
||||
return (0, 0, False)
|
||||
else:
|
||||
return (0, arc.weight.value, False)
|
||||
|
||||
|
||||
class ContextState:
|
||||
def __init__(self, graph: ContextGraph, max_states: int):
|
||||
self.graph = graph
|
||||
self.max_states = max_states
|
||||
# [(total score, (score, state_id))]
|
||||
self.states: List[Tuple[float, Tuple[float, int]]] = []
|
||||
|
||||
def __str__(self):
|
||||
return ";".join([str(state) for state in self.states])
|
||||
|
||||
def clone(self):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
new_context_state.states = self.states[:]
|
||||
return new_context_state
|
||||
|
||||
def finalize(self) -> float:
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
if len(self.states) == 0:
|
||||
return 0, new_context_state
|
||||
item = heappop(self.states)
|
||||
return item[0], new_context_state
|
||||
|
||||
def forward_one_step(self, label: int) -> float:
|
||||
states = self.states[:]
|
||||
new_states = []
|
||||
# expand current label from state state
|
||||
status = self.graph.get_next_state(0, label)
|
||||
if status[2]:
|
||||
heappush(new_states, (-status[1], (status[1], status[0])))
|
||||
else:
|
||||
assert status[0] == 0 and status[2] == False, status
|
||||
|
||||
# the score we have added to the path till now
|
||||
prev_max_total_score = 0
|
||||
# expand previous states with given label
|
||||
while states:
|
||||
state = heappop(states)
|
||||
if -state[0] > prev_max_total_score:
|
||||
prev_max_total_score = -state[0]
|
||||
|
||||
status = self.graph.get_next_state(state[1][1], label)
|
||||
|
||||
if status[2]:
|
||||
heappush(new_states, (state[0] - status[1], (status[1], status[0])))
|
||||
else:
|
||||
pass
|
||||
# assert status == (0, state[0], False), status
|
||||
num_states_drop = (
|
||||
0
|
||||
if len(new_states) <= self.max_states
|
||||
else len(new_states) - self.max_states
|
||||
)
|
||||
|
||||
states = []
|
||||
if len(new_states) == 0:
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -prev_max_total_score, new_context_state
|
||||
|
||||
item = heappop(new_states)
|
||||
|
||||
# if one item match a context, clear all states (means start a new context
|
||||
# from next label), and return the score of current label
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
|
||||
max_total_score = -item[0]
|
||||
heappush(states, item)
|
||||
|
||||
while num_states_drop != 0:
|
||||
item = heappop(new_states)
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
num_states_drop -= 1
|
||||
|
||||
while new_states:
|
||||
item = heappop(new_states)
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
heappush(states, item)
|
||||
# no context matched, the matching may continue with previous prefix,
|
||||
# or change to another prefix.
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
new_context_state.states = states
|
||||
return max_total_score - prev_max_total_score, new_context_state
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user