# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Dict

import k2
import torch

from utils import Hypothesis, HypothesisList
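
# For reference, a sketch of the interface this file expects from `utils`.
# Only the members actually used below are listed, and the exact
# definitions in utils.py may differ:
#
#     @dataclass
#     class Hypothesis:
#         ys: List[int]            # token IDs decoded so far
#         log_prob: torch.Tensor   # acoustic log-prob of `ys`
#         # active LG states -> best score of reaching each state
#         ngram_state_and_scores: Dict[int, torch.Tensor]
#
#     class HypothesisList:
#         def add(self, hyp: Hypothesis) -> None: ...
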

def shallow_fusion(
    LG: k2.Fsa,
    hyp: Hypothesis,
    tokens: torch.Tensor,
    log_probs: torch.Tensor,
    vocab_size: int,
    blank_log_prob: torch.Tensor,
) -> HypothesisList:
    """
    Args:
      LG:
        An n-gram LM represented as an FSA. It should be arc sorted,
        deterministic, and epsilon free. It contains disambig IDs and
        back-off arcs.
      hyp:
        The current hypothesis.
      tokens:
        The possible tokens that will be expanded from the given `hyp`.
        It is a 1-D tensor of dtype torch.int32.
      log_probs:
        It contains the acoustic log probabilities of each path that
        is extended from `hyp.ys` with `tokens`.
        log_probs.shape == tokens.shape.
      vocab_size:
        Vocabulary size, including the blank symbol. We assume that
        token IDs >= vocab_size are disambig IDs (including the backoff
        symbol #0).
      blank_log_prob:
        The log_prob for the blank token at this frame. It is from
        the output of the joiner.
    Returns:
      Return new hypotheses by extending the given `hyp` with tokens in the
      given `tokens`.
    """
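    # Overview: the function runs in two phases.
    # Phase 1 follows out-going arcs whose labels are disambig IDs
    # (>= vocab_size, including the backoff symbol #0), growing
    # `state_and_scores` until no new LG state appears.
    # Phase 2 matches each candidate token against the (sorted) arc
    # labels of every surviving state and emits one new Hypothesis per
    # match, plus the blank and backoff-to-start cases.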
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()

    state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)

    current_states = list(state_and_scores.keys())

    # Process out-going arcs with label equal to disambig tokens or #0
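    # E.g. (hypothetical numbers): if state 5 carries a #0 backoff arc to
    # state 0 with score -2.3 and state_and_scores[5] == -1.0, then state 0
    # enters the map with score -3.3 and is itself queued, so its own
    # disambig arcs are followed as well.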
    while len(current_states) > 0:
        s = current_states.pop()
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        for i in reversed(range(labels.numel())):
            lab = labels[i]
            if lab == -1:
                # Note: When sorting arcs, k2 treats arc labels as
                # unsigned types, so the final-arc label -1 sorts last;
                # skip it and keep scanning
                continue

            if lab < vocab_size:
                # Since LG is arc sorted, we can exit
                # the for loop as soon as we have a label
                # with ID less than vocab_size
                break
            # This is a disambig token or #0
            idx = labels_begin + i
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]
            if next_state not in state_and_scores:
                state_and_scores[next_state] = score
                current_states.append(next_state)
            else:
                # Keep the better of the two paths reaching this state
                state_and_scores[next_state] = max(
                    score, state_and_scores[next_state]
                )

    current_states = list(state_and_scores.keys())
    ans = HypothesisList()

    device = log_probs.device
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        if labels[-1] == -1:
            # Strip the final-arc label so that `labels` stays sorted
            # for torch.searchsorted below
            labels = labels[:-1]
        if s != 0:
            # We add a backoff arc to the start state. Otherwise,
            # all active states may die due to an out-of-vocabulary word.
            new_hyp = Hypothesis(
                ys=hyp.ys[:],
                log_prob=hyp.log_prob + blank_log_prob,
                ngram_state_and_scores={
                    # -20 is the cost on the backoff arc to the start state.
                    # As LG.scores.min() is about -16.6, we choose -20 here.
                    # You may need to tune this value.
                    0: torch.full((1,), -20, dtype=torch.float32, device=device)
                },
            )
            ans.add(new_hyp)

        # Binary-search each candidate token in the sorted arc labels;
        # pos[i] is the index of the first label >= tokens[i]
        pos = torch.searchsorted(labels, tokens)
        for i in range(pos.numel()):
            if tokens[i] == 0:
                # blank ID; the set of active LG states is unchanged
                new_hyp = Hypothesis(
                    ys=hyp.ys[:],
                    log_prob=hyp.log_prob + log_probs[i],
                    ngram_state_and_scores=hyp.ngram_state_and_scores,
                )
                ans.add(new_hyp)
                continue
            elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
                # No out-going arc from this state has a label
                # equal to tokens[i]
                continue

            # Found one matching arc
            idx = labels_begin + pos[i]
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]
            new_hyp = Hypothesis(
                ys=hyp.ys + [tokens[i].item()],
                log_prob=hyp.log_prob + log_probs[i],
                ngram_state_and_scores={next_state: score},
            )
            ans.add(new_hyp)

    return ans
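

if __name__ == "__main__":
    # A minimal smoke test -- a sketch only.  The FSA below is a toy
    # acceptor standing in for a real LG graph, and the Hypothesis /
    # HypothesisList usage mirrors the calls above; the constructor
    # signatures are assumptions based on that usage, and the actual
    # classes in utils.py may expose more.
    s = """
0 1 1 -0.5
0 2 2 -1.5
1 3 -1 0.0
2 3 -1 0.0
3
"""
    LG = k2.arc_sort(k2.Fsa.from_str(s))
    hyp = Hypothesis(
        ys=[],
        log_prob=torch.zeros(1),
        ngram_state_and_scores={0: torch.zeros(1)},  # start at LG state 0
    )
    tokens = torch.tensor([0, 1, 2], dtype=torch.int32)  # blank plus two tokens
    log_probs = torch.tensor([-0.1, -2.0, -3.0])
    new_hyps = shallow_fusion(
        LG=LG,
        hyp=hyp,
        tokens=tokens,
        log_probs=log_probs,
        vocab_size=3,  # blank + two real tokens; the toy LG has no disambig IDs
        blank_log_prob=log_probs[0],
    )
    print(new_hyps)  # expect three hypotheses: blank, [1], and [2]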