This commit is contained in:
pkufool 2023-05-05 14:42:47 +08:00
parent 4eb356ce49
commit 07587d106a
5 changed files with 186 additions and 56 deletions

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
@ -750,6 +751,9 @@ class Hypothesis:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
def __str__(self) -> str:
return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
@ -887,6 +891,7 @@ def modified_beam_search(
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_graph: Optional[ContextGraph] = None,
num_context_history: int = 1,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
@ -938,7 +943,7 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
- context_state=ContextState(state_id=0),
+ context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history),
timestamp=[],
)
)
@ -961,6 +966,7 @@ def modified_beam_search(
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
@ -1018,30 +1024,24 @@ def modified_beam_search(
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
- new_context_state = None
+ context_score = 0
+ new_context_state = None if context_graph is None else hyp.context_state.clone()
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
if context_graph is not None:
- new_context_state = context_graph.get_next_state(
-     hyp.context_state.state_id, new_token
- )
- new_log_prob = topk_log_probs[k] + (
-     0
-     if new_context_state is None
-     else new_context_state.score
- )
+ context_score, new_context_state = hyp.context_state.forward_one_step(new_token)
+ new_log_prob = topk_log_probs[k] + context_score
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
- context_state=hyp.context_state
- if new_context_state is None
- else new_context_state,
+ context_state=new_context_state,
)
B[i].add(new_hyp)
@ -1053,20 +1053,15 @@ def modified_beam_search(
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
- if hyp.context_state.state_id != 0:
-     new_context_state = context_graph.get_next_state(
-         hyp.context_state.state_id, 0
-     )
-     finalized_B[i].add(
-         Hypothesis(
-             ys=hyp.ys,
-             log_prob=hyp.log_prob + new_context_state.score,
-             timestamp=hyp.timestamp,
-             context_state=new_context_state,
-         )
-     )
- else:
-     finalized_B[i].add(hyp)
+ context_score, new_context_state = hyp.context_state.finalize()
+ finalized_B[i].add(
+     Hypothesis(
+         ys=hyp.ys,
+         log_prob=hyp.log_prob + context_score,
+         timestamp=hyp.timestamp,
+         context_state=new_context_state,
+     )
+ )
B = finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
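To summarize the protocol this file now follows: every hypothesis carries a ContextState that is cloned across blank/unk tokens, advanced with forward_one_step on real tokens (which returns an incremental biasing score), and finalized at the end of the utterance so that bonuses of unfinished matches are cancelled. A condensed sketch of the per-token update; rescore_token is an illustrative name, not part of the diff:

def rescore_token(hyp, new_token, base_log_prob, context_graph, blank_id, unk_id):
    # hypothetical helper mirroring the loop above
    context_score = 0.0
    new_state = None if context_graph is None else hyp.context_state.clone()
    if new_token not in (blank_id, unk_id) and context_graph is not None:
        # incremental biasing score plus the advanced matching state
        context_score, new_state = hyp.context_state.forward_one_step(new_token)
    return base_log_prob + context_score, new_state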

View File

@ -131,6 +131,8 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import kaldifst
import graphviz
import sentencepiece as spm
import torch
import torch.nn as nn
@ -363,6 +365,13 @@ def get_parser():
help="",
)
parser.add_argument(
"--num-context-history",
type=int,
default=1,
help="",
)
parser.add_argument(
"--context-file",
type=str,
@ -511,6 +520,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
num_context_history=params.num_context_history,
return_timestamps=True,
)
else:
@ -565,7 +575,10 @@ def decode_one_batch(
return {key: (hyps, timestamps)}
else:
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
key = f"beam_size_{params.beam_size}"
key += f"-context-score-{params.context_score}"
key += f"-num-context-history-{params.num_context_history}"
return {key: (hyps, timestamps)}
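For instance (values invented):

# beam_size=4, context_score=2.0, num_context_history=2 yields
# key == "beam_size_4-context-score-2.0-num-context-history-2"

The same naming scheme is reused for params.suffix below.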
def decode_dataset(
@ -614,7 +627,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
- log_interval = 20
+ log_interval = 1
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -769,6 +782,8 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -381,6 +381,7 @@ class LibriSpeechAsrDataModule:
)
sampler = DynamicBucketingSampler(
cuts,
num_buckets=2,
max_duration=self.args.max_duration,
shuffle=False,
)
@ -452,6 +453,13 @@ class LibriSpeechAsrDataModule:
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
)
@lru_cache()
def test_book_test_cuts(self) -> CutSet:
logging.info("About to get test-books cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
)
@lru_cache()
def test_book2_cuts(self) -> CutSet:
logging.info("About to get test-books2 cuts")

View File

@ -287,6 +287,13 @@ def get_parser():
help="",
)
parser.add_argument(
"--num-context-history",
type=int,
default=1,
help="",
)
parser.add_argument(
"--context-file",
type=str,
@ -389,6 +396,7 @@ def decode_one_batch(
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
context_graph=context_graph,
num_context_history=params.num_context_history,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -429,7 +437,7 @@ def decode_one_batch(
}
else:
return {
f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
}
@ -568,6 +576,7 @@ def main():
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -14,11 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from heapq import heappush, heappop
import re
from dataclasses import dataclass
- from typing import List
+ from typing import List, Tuple
import argparse
import k2
import kaldifst
@ -27,17 +26,13 @@ import sentencepiece as spm
from icefall.utils import is_module_available
- @dataclass
- class ContextState:
-     state_id: int = 0
-     score: float = 0.0
class ContextGraph:
def __init__(self, context_score: float = 1):
self.context_score = context_score
- def build_context_graph_char(self, contexts: List[str], token_table: k2.SymbolTable):
+ def build_context_graph_char(
+     self, contexts: List[str], token_table: k2.SymbolTable
+ ):
"""Convert a list of texts to a list-of-list of token IDs.
Args:
@ -56,7 +51,7 @@ class ContextGraph:
whitespace = re.compile(r"([ \t])")
for text in contexts:
text = re.sub(whitespace, "", text)
- sub_ids : List[int] = []
+ sub_ids: List[int] = []
skip = False
for txt in text:
if txt not in token_table:
@ -69,7 +64,9 @@ class ContextGraph:
ids.append(sub_ids)
self.build_context_graph(ids)
- def build_context_graph_bpe(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+ def build_context_graph_bpe(
+     self, contexts: List[str], sp: spm.SentencePieceProcessor
+ ):
contexts_bpe = sp.encode(contexts)
self.build_context_graph(contexts_bpe)
@ -80,7 +77,7 @@ class ContextGraph:
) # 1st state will be state 0 (returned by add_state)
assert start_state == 0, start_state
graph.start = 0 # set the start state to 0
- graph.set_final(start_state, weight=0)  # weight is in log space
+ graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
for tokens in token_ids:
prev_state = start_state
@ -111,22 +108,128 @@ class ContextGraph:
prev_state = next_state
backoff_score += score
self.graph = kaldifst.determinize(graph)
kaldifst.arcsort(self.graph)
- def get_next_state(self, state_id: int, label: int) -> ContextState:
-     next_state = 0
-     score = 0
-     for arc in kaldifst.ArcIterator(self.graph, state_id):
-         if arc.ilabel == 0:
-             score = arc.weight.value
-         elif arc.ilabel == label:
-             next_state = arc.nextstate
-             score = arc.weight.value
-             break
-     return ContextState(
-         state_id=next_state,
-         score=score,
-     )
def is_final_state(self, state_id: int) -> bool:
return self.graph.final(state_id) == kaldifst.TropicalWeight.one
def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
arc_iter = kaldifst.ArcIterator(self.graph, state_id)
num_arcs = self.graph.num_arcs(state_id)
# The graph is arc-sorted by ilabel, so we can use binary search below.
left = 0
right = num_arcs - 1
while left <= right:
mid = (left + right) // 2
arc_iter.seek(mid)
arc = arc_iter.value
if arc.ilabel < label:
left = mid + 1
elif arc.ilabel > label:
right = mid - 1
else:
return (arc.nextstate, arc.weight.value, True)
# No match: back off to the start state, applying the score on the
# epsilon backoff arc (ilabel == 0) if this state has one.
arc_iter.seek(0)
arc = arc_iter.value
if arc.ilabel == 0:
    return (0, arc.weight.value, False)
else:
    return (0, 0, False)
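A small usage sketch of the lookup convention (the label value is invented): the returned tuple is (next_state, score, matched); on a miss the search falls back to the start state, with the epsilon backoff penalty when present:

# hypothetical lookup from the start state
next_state, score, matched = graph.get_next_state(0, label=3)
if matched:
    state_id = next_state  # extend the partial match; `score` is the bonus
else:
    state_id = 0           # backed off; `score` is the backoff penalty (or 0)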
class ContextState:
def __init__(self, graph: ContextGraph, max_states: int):
self.graph = graph
self.max_states = max_states
# min-heap of (negated total score, (last arc score, state_id))
self.states: List[Tuple[float, Tuple[float, int]]] = []
def __str__(self):
return ";".join([str(state) for state in self.states])
def clone(self):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
new_context_state.states = self.states[:]
return new_context_state
# Returns (score delta, fresh state); the delta is the negated best
# accumulated bonus, so adding it to log_prob cancels the boost of an
# unfinished match.
def finalize(self) -> Tuple[float, "ContextState"]:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
if len(self.states) == 0:
return 0, new_context_state
item = heappop(self.states)
return item[0], new_context_state
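A hedged numeric illustration of the cancellation, with made-up values: suppose a hypothesis has matched three tokens of an unfinished phrase at +1.0 each; self.states stores negated totals, so finalize pops -3.0 and returns it, and adding it to log_prob removes the unearned boost:

state = ContextState(graph=graph, max_states=1)  # `graph` built as above
state.states = [(-3.0, (1.0, 7))]                # (negated total, (arc score, state_id)), invented
score, fresh = state.finalize()
assert score == -3.0 and fresh.states == []      # boost cancelled, state reset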
# Returns (score delta, new state); the delta is the change in the
# best accumulated bonus after consuming `label`.
def forward_one_step(self, label: int) -> Tuple[float, "ContextState"]:
states = self.states[:]
new_states = []
# expand the current label from the start state
status = self.graph.get_next_state(0, label)
if status[2]:
heappush(new_states, (-status[1], (status[1], status[0])))
else:
assert status[0] == 0 and not status[2], status
# the best context bonus already added to this path so far
prev_max_total_score = 0
# expand the previous states with the given label
while states:
state = heappop(states)
if -state[0] > prev_max_total_score:
prev_max_total_score = -state[0]
status = self.graph.get_next_state(state[1][1], label)
if status[2]:
heappush(new_states, (state[0] - status[1], (status[1], status[0])))
else:
pass
# assert status == (0, state[0], False), status
states = []
if len(new_states) == 0:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -prev_max_total_score, new_context_state
item = heappop(new_states)
# if an item reaches a final state (a full context phrase matched), clear
# all states (i.e., start a new match from the next label) and return the
# score of the current label
if self.graph.is_final_state(item[1][1]):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -item[0] - prev_max_total_score, new_context_state
max_total_score = -item[0]
heappush(states, item)
# keep only the best `max_states` entries; every popped item is still
# checked for finality first, so a completed match is never dropped
while new_states:
    item = heappop(new_states)
    if self.graph.is_final_state(item[1][1]):
        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
        return -item[0] - prev_max_total_score, new_context_state
    if len(states) < self.max_states:
        heappush(states, item)
# no context phrase completed; the match may continue with the previous
# prefix or switch to another prefix
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
new_context_state.states = states
return max_total_score - prev_max_total_score, new_context_state
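Putting the pieces together, a minimal end-to-end sketch (token IDs invented; per the builders above, build_context_graph takes the context phrases as lists of token IDs):

graph = ContextGraph(context_score=1.0)
graph.build_context_graph([[3, 5], [3, 7, 9]])  # two hypothetical phrases

state = ContextState(graph=graph, max_states=2)
total = 0.0
for token in [3, 5]:                 # completes the phrase [3, 5]
    delta, state = state.forward_one_step(token)
    total += delta                   # roughly +1.0 per matched token
delta, state = state.finalize()      # nothing pending afterwards: delta == 0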
if __name__ == "__main__":
parser = argparse.ArgumentParser()