Add librispeech prefix-beam-search

pkufool 2024-09-27 19:31:54 +08:00
parent cef16574f7
commit 02e00ff504
3 changed files with 366 additions and 61 deletions


@@ -123,6 +123,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse import set_caching_enabled
 from train import add_model_arguments, get_model, get_params
+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -131,6 +135,9 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -280,6 +287,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -301,7 +325,7 @@ def get_parser():
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )

     add_model_arguments(parser)
@@ -314,8 +338,9 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
+            "beam": 4,  # for prefix-beam-search
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
@@ -333,6 +358,7 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +403,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +434,48 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}

+    if params.decoding_method == "prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            LM=LM,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +649,7 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -634,6 +700,7 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            LM=LM,
         )

         for name, hyps in hyps_dict.items():
@@ -664,9 +731,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +745,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +787,7 @@ def save_wer_results(
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -735,8 +802,11 @@ def main():
     set_caching_enabled(True)  # lhotse

     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
+        "ctc-greedy-search",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +832,11 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"

+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            params.suffix += f"_lm-scale-{params.lm_scale}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -772,6 +847,8 @@ def main():
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)

+    params.device = device
+
     logging.info(f"Device: {device}")
     logging.info(params)
@@ -786,14 +863,24 @@ def main():
     params.sos_id = 1

     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-greedy-search",
+        "ctc-decoding",
+        "attention-decoder-rescoring-no-ngram",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
+        H = None
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+                device=device,
+            )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:
@@ -844,7 +931,8 @@ def main():
             G = k2.Fsa.from_dict(d)

         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
@@ -858,6 +946,19 @@ def main():
     else:
         G = None

+    # only load the neural network LM if required
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
     logging.info("About to create model")
     model = get_model(params)
@@ -967,6 +1068,7 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
+            LM=LM,
         )

         save_asr_output(


@@ -1511,30 +1511,43 @@ def ctc_greedy_search(
 class Hypothesis:
     # The predicted tokens so far.
     # Newly predicted tokens are appended to `ys`.
-    ys: List[int]
+    ys: List[int] = field(default_factory=list)

     # The log prob of ys.
     # It contains only one entry.
-    log_prob_blank: torch.Tensor
+    log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)

-    log_prob_non_blank: torch.Tensor
+    log_prob_non_blank: torch.Tensor = torch.tensor(
+        [float("-inf")], dtype=torch.float32
+    )

     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
     timestamp: List[int] = field(default_factory=list)

-    # the lm score for next token given the current ys
-    lm_score: Optional[torch.Tensor] = None
+    # The lm score of ys
+    # It contains only one entry
+    lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)
+
+    # the lm log_probs for next token given the history ys
+    lm_log_probs: Optional[torch.Tensor] = None

     # the RNNLM states (h and c in LSTM)
     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

     # N-gram LM state
-    state_cost: Optional[NgramLmStateCost] = None
+    LODR_state: Optional[NgramLmStateCost] = None
+
+    # N-gram LM state
+    Ngram_state: Optional[NgramLmStateCost] = None

     # Context graph state
     context_state: Optional[ContextState] = None

+    @property
+    def tot_score(self) -> torch.Tensor:
+        return self.log_prob + self.lm_score
+
     @property
     def log_prob(self) -> torch.Tensor:
         return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
@@ -1544,6 +1557,20 @@ class Hypothesis:
         """Return a tuple representation of self.ys"""
         return tuple(self.ys)

+    def clone(self) -> "Hypothesis":
+        return Hypothesis(
+            ys=self.ys,
+            log_prob_blank=self.log_prob_blank,
+            log_prob_non_blank=self.log_prob_non_blank,
+            timestamp=self.timestamp,
+            lm_log_probs=self.lm_log_probs,
+            lm_score=self.lm_score,
+            state=self.state,
+            LODR_state=self.LODR_state,
+            Ngram_state=self.Ngram_state,
+            context_state=self.context_state,
+        )
+
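A small self-contained sketch of how the new defaults, tot_score, and clone() behave; Hypothesis is the dataclass shown in this hunk, and the token id 25 is only illustrative.

from icefall.decode import Hypothesis

hyp = Hypothesis()                  # empty prefix: log_prob_blank=0, log_prob_non_blank=-inf
assert hyp.log_prob.item() == 0.0   # logaddexp(0, -inf) == 0
assert hyp.tot_score.item() == 0.0  # log_prob + lm_score, and lm_score starts at 0

child = hyp.clone()
# clone() copies references (shallow copy), so extend the prefix by rebinding, never by append:
child.ys = hyp.ys + [25]            # OK: builds a new list
# child.ys.append(25)               # WRONG: would also mutate hyp.ys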
 class HypothesisList(object):
     def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None:
@@ -1597,9 +1624,9 @@ class HypothesisList(object):
         Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
         else:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            return max(self._data.values(), key=lambda hyp: hyp.tot_score)

     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.
@@ -1629,7 +1656,7 @@ class HypothesisList(object):
         """
         ans = HypothesisList()
         for _, hyp in self._data.items():
-            if hyp.log_prob > threshold:
+            if hyp.tot_score > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
@@ -1645,17 +1672,20 @@ class HypothesisList(object):
         if length_norm:
             hyps = sorted(
-                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
+                hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True
             )[:k]
         else:
-            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+            hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k]

         ans = HypothesisList(dict(hyps))
         return ans

-    def __contains__(self, key: str):
+    def __contains__(self, key: tuple):
         return key in self._data

+    def __getitem__(self, key: tuple):
+        return self._data[key]
+
     def __iter__(self):
         return iter(self._data.values())
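A usage sketch of the updated HypothesisList interface. It assumes the unchanged add() method (outside this hunk) stores entries under hyp.key, i.e. tuple(ys); the scores are illustrative.

import torch

from icefall.decode import Hypothesis, HypothesisList

neg_inf = torch.tensor([float("-inf")])

hyps = HypothesisList()
hyps.add(
    Hypothesis(ys=[7], log_prob_non_blank=torch.tensor([-3.0]), log_prob_blank=neg_inf)
)
hyps.add(
    Hypothesis(
        ys=[9],
        log_prob_non_blank=torch.tensor([-2.0]),
        log_prob_blank=neg_inf,
        lm_score=torch.tensor([-0.5]),
    )
)

assert (9,) in hyps                               # __contains__ takes the tuple key
best = hyps.get_most_probable(length_norm=False)  # ys=[9]: tot_score -2.5 beats -3.0
top1 = hyps.topk(1)                               # HypothesisList holding only ys=[9]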
@@ -1694,64 +1724,96 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     return ans


-def _step_worker(log_probs, indexes, B, beam, blank_id):
+def _step_worker(
+    log_probs,
+    indexes,
+    B,
+    beam,
+    blank_id,
+    lm_scale: float = 0,
+    LODR_lm_scale: float = 0,
+    context_graph: Optional[ContextGraph] = None,
+):
     A = list(B)
     B = HypothesisList()
     for h in range(len(A)):
         hyp = A[h]
         for k in range(log_probs.size(0)):
             log_prob, index = log_probs[k], indexes[k]
-            if index == blank_id:
+            new_token = index.item()
+            update_prefix = False
+            new_hyp = hyp.clone()
+            if new_token == blank_id:
                 # Case 0: *a + ε => *a
                 #         *aε + ε => *a
                 # Prefix does not change, update log_prob of blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:],
-                    log_prob_non_blank=torch.tensor(
-                        [float("-inf")], dtype=torch.float32
-                    ),
-                    log_prob_blank=hyp.log_prob + log_prob,
-                )
+                new_hyp.log_prob_non_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                new_hyp.log_prob_blank = hyp.log_prob + log_prob
                 B.add(new_hyp)
-            elif len(hyp.ys) > 0 and hyp.ys[-1] == index:
+            elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token:
                 # Case 1: *a + a => *a
                 # Prefix does not change, update log_prob of non_blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:],
-                    log_prob_non_blank=hyp.log_prob_non_blank + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
-                )
+                new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
                 B.add(new_hyp)

                 # Case 2: *aε + a => *aa
                 # Prefix changes, update log_prob of blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:] + [index.item()],
-                    log_prob_non_blank=hyp.log_prob_blank + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
-                )
-                B.add(new_hyp)
+                new_hyp = hyp.clone()
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                update_prefix = True
             else:
                 # Case 3: *a + b => *ab, *aε + b => *ab
                 # Prefix changes, update log_prob of non_blank
-                new_hyp = Hypothesis(
-                    ys=hyp.ys[:] + [index.item()],
-                    log_prob_non_blank=hyp.log_prob + log_prob,
-                    log_prob_blank=torch.tensor([float("-inf")], dtype=torch.float32),
-                )
+                # Caution: DO NOT use append, as clone is shallow copy
+                new_hyp.ys = hyp.ys + [new_token]
+                new_hyp.log_prob_non_blank = hyp.log_prob + log_prob
+                new_hyp.log_prob_blank = torch.tensor(
+                    [float("-inf")], dtype=torch.float32
+                )
+                update_prefix = True
+
+            if update_prefix:
+                lm_score = hyp.lm_score
+                if hyp.lm_log_probs is not None:
+                    lm_score += hyp.lm_log_probs[new_token] * lm_scale
+                    new_hyp.lm_log_probs = None
+
+                if context_graph is not None and hyp.context_state is not None:
+                    context_score, new_context_state = context_graph.forward_one_step(
+                        hyp.context_state, new_token
+                    )
+                    lm_score += context_score
+                    new_hyp.context_state = new_context_state
+
+                if hyp.LODR_state is not None:
+                    state_cost = hyp.LODR_state.forward_one_step(new_token)
+                    # calculate the score of the latest token
+                    current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.LODR_state.lm_score,
+                    )
+                    lm_score += LODR_lm_scale * current_ngram_score
+                    new_hyp.LODR_state = state_cost
+
+                new_hyp.lm_score = lm_score
                 B.add(new_hyp)

     B = B.topk(beam)
     return B
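For reference, a self-contained numeric walk-through of the three update rules implemented above, using plain floats instead of tensors (values are illustrative only).

import math

NEG_INF = float("-inf")


def logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)) computed stably
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# One frame of log-probs: blank = log(0.6), token "a" = log(0.4).
log_blank, log_a = math.log(0.6), math.log(0.4)

# Start from the empty prefix: log_prob_blank = 0.0 (prob 1), log_prob_non_blank = -inf.
total = logaddexp(0.0, NEG_INF)  # 0.0

# Case 0 (emit blank): prefix stays empty, all mass moves to the blank bucket.
hyp_empty = {"ys": [], "blank": total + log_blank, "non_blank": NEG_INF}

# Case 3 (emit "a", a new token): prefix becomes ["a"], mass goes to non_blank.
hyp_a = {"ys": ["a"], "blank": NEG_INF, "non_blank": total + log_a}

# On the next frame, emitting "a" again from hyp_a splits into:
#   Case 1: same prefix ["a"]       -> non_blank = hyp_a["non_blank"] + log_a
#   Case 2: longer prefix ["a","a"] -> non_blank = hyp_a["blank"] + log_a
#           (-inf here, because "aa" needs an intervening blank)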
 def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
-    B.add(
-        Hypothesis(
-            ys=[],
-            log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
-            log_prob_blank=torch.zeros(1, dtype=torch.float32),
-        )
-    )
+    B.add(Hypothesis())
     for j in range(encoder_out_lens):
         log_probs, indexes = topk_values[j], topk_indexes[j]
         B = _step_worker(log_probs, indexes, B, beam, blank_id)
@@ -1763,11 +1825,11 @@ def ctc_prefix_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     blank_id: int = 0,
+    context_graph: Optional[ContextGraph] = None,
     process_pool: Optional[Pool] = None,
     return_nbest: Optional[bool] = False,
 ) -> Union[List[List[int]], List[HypothesisList]]:
     batch_size, num_frames, vocab_size = ctc_output.shape
+    # TODO: using a larger beam for first pass pruning
     topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
     topk_values = topk_values.cpu()
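A minimal sketch of driving this entry point directly, mirroring the new "prefix-beam-search" branch of decode_one_batch in the first file; the random logits and sizes are placeholders, and only the call signature comes from this diff.

import torch

from icefall.decode import ctc_prefix_beam_search

vocab_size = 500
ctc_output = torch.randn(2, 30, vocab_size).log_softmax(dim=-1)  # (N, T, vocab)
encoder_out_lens = torch.tensor([30, 24])  # valid frames per utterance

token_ids = ctc_prefix_beam_search(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    beam=4,  # matches params.beam from get_decoding_params()
)
# token_ids: List[List[int]], one best token sequence per utterance;
# decode_one_batch then maps them to words with bpe_model.decode(token_ids).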
@@ -1800,6 +1862,136 @@ def ctc_prefix_beam_search(
     return [hyp.ys for hyp in best_hyps]


+def ctc_prefix_beam_search_shallow_fussion(
+    ctc_output: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: int = 4,
+    blank_id: int = 0,
+    LODR_lm: Optional[NgramLm] = None,
+    LODR_lm_scale: Optional[float] = 0,
+    LM: Optional[LmScorer] = None,
+    context_graph: Optional[ContextGraph] = None,
+) -> List[List[int]]:
+    batch_size, num_frames, vocab_size = ctc_output.shape
+    # TODO: using a larger beam for first pass pruning
+    topk_values, topk_indexes = ctc_output.topk(beam)  # (B, T, beam)
+    topk_values = topk_values.cpu()
+    topk_indexes = topk_indexes.cpu()
+    encoder_out_lens = encoder_out_lens.tolist()
+    device = ctc_output.device
+
+    lm_scale = 0
+    init_scores = None
+    init_states = None
+
+    if LM is not None:
+        lm_scale = LM.lm_scale
+        sos_id = getattr(LM, "sos_id", 1)
+        # get initial lm score and lm state by scoring the "sos" token
+        sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+        lens = torch.tensor([1]).to(device)
+        init_scores, init_states = LM.score_token(sos_token, lens)
+        init_scores, init_states = init_scores.cpu(), (
+            init_states[0].cpu(),
+            init_states[1].cpu(),
+        )
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[],
+                log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32),
+                log_prob_blank=torch.zeros(1, dtype=torch.float32),
+                lm_score=torch.zeros(1, dtype=torch.float32),
+                state=init_states,
+                lm_log_probs=None if init_scores is None else init_scores.reshape(-1),
+                LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm),
+                context_state=None if context_graph is None else context_graph.root,
+            )
+        )
+
+    for j in range(num_frames):
+        for i in range(batch_size):
+            if j < encoder_out_lens[i]:
+                log_probs, indexes = topk_values[i][j], topk_indexes[i][j]
+                B[i] = _step_worker(
+                    log_probs,
+                    indexes,
+                    B[i],
+                    beam,
+                    blank_id,
+                    lm_scale=lm_scale,
+                    LODR_lm_scale=LODR_lm_scale,
+                    context_graph=context_graph,
+                )
+        if LM is None:
+            continue
+        # update lm_score
+        token_list = []  # a list of list
+        hs = []
+        cs = []
+        indexes = []  # (batch_idx, key)
+        for batch_idx, hyps in enumerate(B):
+            for hyp in hyps:
+                if hyp.lm_log_probs is None:
+                    if LM.lm_type == "rnn":
+                        token_list.append([hyp.ys[-1]])
+                        # store the LSTM states
+                        hs.append(hyp.state[0])
+                        cs.append(hyp.state[1])
+                    else:
+                        # for transformer LM
+                        token_list.append([sos_id] + hyp.ys[:])
+                    indexes.append((batch_idx, hyp.key))
+        if len(token_list) != 0:
+            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+            if LM.lm_type == "rnn":
+                tokens_to_score = (
+                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+                )
+                hs = torch.cat(hs, dim=1).to(device)
+                cs = torch.cat(cs, dim=1).to(device)
+                state = (hs, cs)
+            else:
+                # for transformer LM
+                tokens_list = [torch.tensor(tokens) for tokens in token_list]
+                tokens_to_score = (
+                    torch.nn.utils.rnn.pad_sequence(
+                        tokens_list, batch_first=True, padding_value=0.0
+                    )
+                    .to(device)
+                    .to(torch.int64)
+                )
+                state = None
+            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+            scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu())
+            assert scores.size(0) == len(indexes), (scores.size(0), len(indexes))
+            for i in range(scores.size(0)):
+                batch_idx, key = indexes[i]
+                B[batch_idx][key].lm_log_probs = scores[i]
+                if LM.lm_type == "rnn":
+                    state = (
+                        lm_states[0][:, i, :].unsqueeze(1),
+                        lm_states[1][:, i, :].unsqueeze(1),
+                    )
+                    B[batch_idx][key].state = state
+
+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        for hyps in B:
+            for hyp in hyps:
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                hyp.lm_score += context_score
+                hyp.context_state = new_context_state
+
+    best_hyps = [b.get_most_probable() for b in B]
+    return [hyp.ys for hyp in best_hyps]
+
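A hedged usage sketch of the new shallow-fusion search. `LM` is assumed to be an LmScorer built as in ctc_decode.py's main() (driven by --lm-type and --lm-scale); it is not constructed here, and the input tensors are placeholders.

import torch

from icefall.decode import ctc_prefix_beam_search_shallow_fussion

ctc_output = torch.randn(1, 40, 500).log_softmax(dim=-1)  # (N, T, vocab)
encoder_out_lens = torch.tensor([40])

token_ids = ctc_prefix_beam_search_shallow_fussion(
    ctc_output=ctc_output,
    encoder_out_lens=encoder_out_lens,
    beam=4,
    LM=LM,  # assumed pre-built LmScorer; LM.lm_scale weights its log-probs
)
# Hypotheses are ranked by Hypothesis.tot_score =
#   logaddexp(log_prob_blank, log_prob_non_blank) + lm_score,
# where lm_score accumulates LM.lm_scale * lm_log_probs[token]
# (plus optional context-graph and LODR terms).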
 def ctc_prefix_beam_search_attention_decoder_rescoring(
     ctc_output: torch.Tensor,
     attention_decoder: torch.nn.Module,


@@ -19,8 +19,10 @@
 import argparse
 import collections
+import json
 import logging
 import os
+import pathlib
 import re
 import subprocess
 from collections import defaultdict
@@ -178,6 +180,15 @@ class AttributeDict(dict):
                 return
         raise AttributeError(f"No such attribute '{key}'")

+    def __str__(self, indent: int = 2):
+        tmp = {}
+        for k, v in self.items():
+            # PosixPath is not JSON serializable
+            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+                v = str(v)
+            tmp[k] = v
+        return json.dumps(tmp, indent=indent, sort_keys=True)
+
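A small sketch of what the new __str__ produces for the AttributeDict defined above; the keys and values are illustrative.

import pathlib

import torch

from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "exp_dir": pathlib.Path("zipformer/exp"),
        "device": torch.device("cuda", 0),
        "beam": 4,
    }
)
print(params)
# {
#   "beam": 4,
#   "device": "cuda:0",
#   "exp_dir": "zipformer/exp"
# }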
 def encode_supervisions(
     supervisions: dict,