# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Dict

import k2
import torch


def shallow_fusion(
    LG: k2.Fsa,
    token: int,
    state_and_scores: Dict[int, torch.Tensor],
    vocab_size: int,
) -> Dict[int, torch.Tensor]:
"""
Args:
LG:
An n-gram. It should be arc sorted, deterministic, and epsilon free.
token:
The input token ID.
state_and_scores:
The keys contain the current state we are in and the
values are the LM log_prob for reaching the corresponding
states from the start state.
vocab_size:
Vocabulary size, including the blank symbol. We assume that
token IDs >= vocab_size are disambig IDs (including the backoff
symbol #0).
Returns:
Return a new state_and_scores.
"""
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()
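    # row_splits[s]:row_splits[s + 1] is the range of indexes of the
    # arcs leaving state s; the second entry of each row of `arcs` is
    # the destination state of that arc.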
    state_and_scores = copy.deepcopy(state_and_scores)
    current_states = list(state_and_scores.keys())

    # Process outgoing arcs whose labels are disambig tokens or #0.
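    # `current_states` acts as a work list: states discovered through
    # disambig arcs are appended and processed in turn.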
    while len(current_states) > 0:
        s = current_states.pop()
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        for i in reversed(range(labels.numel())):
            lab = labels[i]
            if lab == -1:
                # Note: When sorting arcs, k2 treats arc labels as
                # unsigned types, so -1 (the final-arc label) is
                # sorted last
                continue

            if lab < vocab_size:
                # Since LG is arc sorted, we can exit the for loop
                # as soon as we see a label with ID less than
                # vocab_size
                break

            # This is a disambig token or #0
            idx = labels_begin + i
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]

            if next_state not in state_and_scores:
                state_and_scores[next_state] = score
                current_states.append(next_state)
            else:
                state_and_scores[next_state] = max(
                    score, state_and_scores[next_state]
                )
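    # At this point state_and_scores is closed under disambig/backoff
    # arcs: it contains every state reachable from the input states via
    # such arcs, with the best LM score for each.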
    current_states = list(state_and_scores.keys())

    ans = dict()
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()
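        # Drop the final-arc label (-1), which sorts last, so that the
        # remaining labels are in ascending order and binary search for
        # `token` is valid.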
        if labels[-1] == -1:
            labels = labels[:-1]

        pos = torch.searchsorted(labels, token)
        if pos >= labels.numel() or labels[pos] != token:
            # No outgoing arc from this state has label `token`
            continue

        idx = labels_begin + pos
        next_state = arcs[idx][1].item()
        score = LG.scores[idx] + state_and_scores[s]

        if next_state not in ans:
            ans[next_state] = score
        else:
            ans[next_state] = max(score, ans[next_state])

    return ans
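

if __name__ == "__main__":
    # A minimal sketch (not part of the original file) showing how
    # shallow_fusion might be called. Instead of a real LG graph we
    # build a tiny arc-sorted acceptor by hand; the topology, scores,
    # token IDs, and vocab_size below are made up for illustration.
    s = (
        "0 1 2 0.1\n"  # state 0 -> state 1 on token 2, score 0.1
        "0 2 3 0.2\n"  # state 0 -> state 2 on token 3, score 0.2
        "1 3 -1 0.0\n"  # final arc from state 1
        "2 3 -1 0.0\n"  # final arc from state 2
        "3\n"  # final state
    )
    LG = k2.arc_sort(k2.Fsa.from_str(s))

    # Start in the start state (0) with an LM score of 0
    state_and_scores = {0: torch.tensor(0.0)}

    # Consume token 2; the only matching arc is 0 -> 1 with score 0.1
    ans = shallow_fusion(
        LG, token=2, state_and_scores=state_and_scores, vocab_size=10
    )
    print(ans)  # expected: {1: tensor(0.1000)}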